Compare commits

..

7 Commits

Author SHA1 Message Date
Conrad Ludgate 3208220456 parameterize proxy test fixtures 2024-10-29 10:36:45 +00:00
Conrad Ludgate ea1e84b585 stubgen typechecking 2024-10-29 09:57:22 +00:00
Conrad Ludgate 4afd53c8b2 slight refactor 2024-10-28 17:33:02 +00:00
Conrad Ludgate 4df5b7631c type checking 2024-10-28 17:27:43 +00:00
Conrad Ludgate 6a28a47708 fmt 2024-10-28 15:55:56 +00:00
Conrad Ludgate 3c880f6bb3 fix port allocation 2024-10-28 14:47:55 +00:00
Conrad Ludgate 33751f2805 [auth_broker]: regress test 2024-10-28 13:39:44 +00:00
24 changed files with 1361 additions and 255 deletions

View File

@@ -839,7 +839,6 @@ jobs:
- name: Build vm image
run: |
./vm-builder \
-size=2G \
-spec=compute/vm-image-spec-${{ matrix.version.debian }}.yaml \
-src=neondatabase/compute-node-${{ matrix.version.pg }}:${{ needs.tag.outputs.build-tag }} \
-dst=neondatabase/vm-compute-node-${{ matrix.version.pg }}:${{ needs.tag.outputs.build-tag }}
@@ -1119,6 +1118,7 @@ jobs:
-f deployPgSniRouter=true \
-f deployProxyLink=true \
-f deployPrivatelinkProxy=true \
-f deployLegacyProxyScram=true \
-f deployProxyScram=true \
-f deployProxyAuthBroker=true \
-f branch=main \

View File

@@ -666,7 +666,7 @@ RUN apt-get update && \
#
# Use new version only for v17
# because Release_2024_09_1 has some backward incompatible changes
# https://github.com/rdkit/rdkit/releases/tag/Release_2024_09_1
ENV PATH="/usr/local/pgsql/bin/:/usr/local/pgsql/:$PATH"
RUN case "${PG_VERSION}" in \
"v17") \
@@ -860,14 +860,13 @@ ENV PATH="/home/nonroot/.cargo/bin:/usr/local/pgsql/bin/:$PATH"
USER nonroot
WORKDIR /home/nonroot
RUN curl -sSO https://static.rust-lang.org/rustup/dist/$(uname -m)-unknown-linux-gnu/rustup-init && \
RUN case "${PG_VERSION}" in "v17") \
echo "v17 is not supported yet by pgrx. Quit" && exit 0;; \
esac && \
curl -sSO https://static.rust-lang.org/rustup/dist/$(uname -m)-unknown-linux-gnu/rustup-init && \
chmod +x rustup-init && \
./rustup-init -y --no-modify-path --profile minimal --default-toolchain stable && \
rm rustup-init && \
case "${PG_VERSION}" in \
'v17') \
echo 'v17 is not supported yet by pgrx. Quit' && exit 0;; \
esac && \
cargo install --locked --version 0.11.3 cargo-pgrx && \
/bin/bash -c 'cargo pgrx init --pg${PG_VERSION:1}=/usr/local/pgsql/bin/pg_config'
@@ -1042,31 +1041,6 @@ RUN wget https://github.com/pgpartman/pg_partman/archive/refs/tags/v5.1.0.tar.gz
make -j $(getconf _NPROCESSORS_ONLN) install && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/pg_partman.control
#########################################################################################
#
# Layer "pg_mooncake"
# compile pg_mooncake extension
#
#########################################################################################
FROM rust-extensions-build AS pg-mooncake-build
ARG PG_VERSION
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
ENV PG_MOONCAKE_VERSION=0a7de4c0b5c7b1a5e2175e1c5f4625b97b7346f1
ENV PATH="/usr/local/pgsql/bin/:$PATH"
RUN case "${PG_VERSION}" in \
'v14') \
echo "pg_mooncake is not supported on Postgres ${PG_VERSION}" && exit 0;; \
esac && \
git clone --depth 1 --branch neon https://github.com/Mooncake-Labs/pg_mooncake.git pg_mooncake-src && \
cd pg_mooncake-src && \
git checkout "${PG_MOONCAKE_VERSION}" && \
git submodule update --init --depth 1 --recursive && \
make BUILD_TYPE=release -j $(getconf _NPROCESSORS_ONLN) && \
make BUILD_TYPE=release -j $(getconf _NPROCESSORS_ONLN) install && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/pg_mooncake.control
#########################################################################################
#
# Layer "neon-pg-ext-build"
@@ -1110,7 +1084,6 @@ COPY --from=wal2json-pg-build /usr/local/pgsql /usr/local/pgsql
COPY --from=pg-anon-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-ivm-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-partman-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-mooncake-build /usr/local/pgsql/ /usr/local/pgsql/
COPY pgxn/ pgxn/
RUN make -j $(getconf _NPROCESSORS_ONLN) \

poetry.lock generated
View File

@@ -1521,6 +1521,21 @@ files = [
[package.dependencies]
six = "*"
[[package]]
name = "jwcrypto"
version = "1.5.6"
description = "Implementation of JOSE Web standards"
optional = false
python-versions = ">= 3.8"
files = [
{file = "jwcrypto-1.5.6-py3-none-any.whl", hash = "sha256:150d2b0ebbdb8f40b77f543fb44ffd2baeff48788be71f67f03566692fd55789"},
{file = "jwcrypto-1.5.6.tar.gz", hash = "sha256:771a87762a0c081ae6166958a954f80848820b2ab066937dc8b8379d65b1b039"},
]
[package.dependencies]
cryptography = ">=3.4"
typing-extensions = ">=4.5.0"
[[package]]
name = "kafka-python"
version = "2.0.2"
@@ -2111,7 +2126,6 @@ files = [
{file = "psycopg2_binary-2.9.9-cp311-cp311-win32.whl", hash = "sha256:dc4926288b2a3e9fd7b50dc6a1909a13bbdadfc67d93f3374d984e56f885579d"},
{file = "psycopg2_binary-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:b76bedd166805480ab069612119ea636f5ab8f8771e640ae103e05a4aae3e417"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:8532fd6e6e2dc57bcb3bc90b079c60de896d2128c5d9d6f24a63875a95a088cf"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b0605eaed3eb239e87df0d5e3c6489daae3f7388d455d0c0b4df899519c6a38d"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f8544b092a29a6ddd72f3556a9fcf249ec412e10ad28be6a0c0d948924f2212"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2d423c8d8a3c82d08fe8af900ad5b613ce3632a1249fd6a223941d0735fce493"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2e5afae772c00980525f6d6ecf7cbca55676296b580c0e6abb407f15f3706996"},
@@ -2120,8 +2134,6 @@ files = [
{file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:cb16c65dcb648d0a43a2521f2f0a2300f40639f6f8c1ecbc662141e4e3e1ee07"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:911dda9c487075abd54e644ccdf5e5c16773470a6a5d3826fda76699410066fb"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:57fede879f08d23c85140a360c6a77709113efd1c993923c59fde17aa27599fe"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-win32.whl", hash = "sha256:64cf30263844fa208851ebb13b0732ce674d8ec6a0c86a4e160495d299ba3c93"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:81ff62668af011f9a48787564ab7eded4e9fb17a4a6a74af5ffa6a457400d2ab"},
{file = "psycopg2_binary-2.9.9-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:2293b001e319ab0d869d660a704942c9e2cce19745262a8aba2115ef41a0a42a"},
{file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03ef7df18daf2c4c07e2695e8cfd5ee7f748a1d54d802330985a78d2a5a6dca9"},
{file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a602ea5aff39bb9fac6308e9c9d82b9a35c2bf288e184a816002c9fae930b77"},
@@ -2603,7 +2615,6 @@ files = [
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@@ -2912,6 +2923,20 @@ files = [
{file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"},
]
[[package]]
name = "types-jwcrypto"
version = "1.5.0.20240925"
description = "Typing stubs for jwcrypto"
optional = false
python-versions = ">=3.8"
files = [
{file = "types-jwcrypto-1.5.0.20240925.tar.gz", hash = "sha256:50e17b790378c96239344476c7bd13b52d0c7eeb6d16c2d53723e48cc6bbf4fe"},
{file = "types_jwcrypto-1.5.0.20240925-py3-none-any.whl", hash = "sha256:2d12a2d528240d326075e896aafec7056b9136bf3207fa6ccf3fcb8fbf9e11a1"},
]
[package.dependencies]
cryptography = "*"
[[package]]
name = "types-psutil"
version = "5.9.5.12"
@@ -3159,16 +3184,6 @@ files = [
{file = "wrapt-1.14.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8ad85f7f4e20964db4daadcab70b47ab05c7c1cf2a7c1e51087bfaa83831854c"},
{file = "wrapt-1.14.1-cp310-cp310-win32.whl", hash = "sha256:a9a52172be0b5aae932bef82a79ec0a0ce87288c7d132946d645eba03f0ad8a8"},
{file = "wrapt-1.14.1-cp310-cp310-win_amd64.whl", hash = "sha256:6d323e1554b3d22cfc03cd3243b5bb815a51f5249fdcbb86fda4bf62bab9e164"},
{file = "wrapt-1.14.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ecee4132c6cd2ce5308e21672015ddfed1ff975ad0ac8d27168ea82e71413f55"},
{file = "wrapt-1.14.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2020f391008ef874c6d9e208b24f28e31bcb85ccff4f335f15a3251d222b92d9"},
{file = "wrapt-1.14.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2feecf86e1f7a86517cab34ae6c2f081fd2d0dac860cb0c0ded96d799d20b335"},
{file = "wrapt-1.14.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:240b1686f38ae665d1b15475966fe0472f78e71b1b4903c143a842659c8e4cb9"},
{file = "wrapt-1.14.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a9008dad07d71f68487c91e96579c8567c98ca4c3881b9b113bc7b33e9fd78b8"},
{file = "wrapt-1.14.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:6447e9f3ba72f8e2b985a1da758767698efa72723d5b59accefd716e9e8272bf"},
{file = "wrapt-1.14.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:acae32e13a4153809db37405f5eba5bac5fbe2e2ba61ab227926a22901051c0a"},
{file = "wrapt-1.14.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:49ef582b7a1152ae2766557f0550a9fcbf7bbd76f43fbdc94dd3bf07cc7168be"},
{file = "wrapt-1.14.1-cp311-cp311-win32.whl", hash = "sha256:358fe87cc899c6bb0ddc185bf3dbfa4ba646f05b1b0b9b5a27c2cb92c2cea204"},
{file = "wrapt-1.14.1-cp311-cp311-win_amd64.whl", hash = "sha256:26046cd03936ae745a502abf44dac702a5e6880b2b01c29aea8ddf3353b68224"},
{file = "wrapt-1.14.1-cp35-cp35m-manylinux1_i686.whl", hash = "sha256:43ca3bbbe97af00f49efb06e352eae40434ca9d915906f77def219b88e85d907"},
{file = "wrapt-1.14.1-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:6b1a564e6cb69922c7fe3a678b9f9a3c54e72b469875aa8018f18b4d1dd1adf3"},
{file = "wrapt-1.14.1-cp35-cp35m-manylinux2010_i686.whl", hash = "sha256:00b6d4ea20a906c0ca56d84f93065b398ab74b927a7a3dbd470f6fc503f95dc3"},
@@ -3406,4 +3421,4 @@ cffi = ["cffi (>=1.11)"]
[metadata]
lock-version = "2.0"
python-versions = "^3.9"
content-hash = "0f4804119f417edf8e1fbd6d715d2e8d70ad731334fa9570304a2203f83339cf"
content-hash = "ad5c9ee7723359af22bbd7fa41538dcf78913c02e947a13a8f9a87eb3a59039e"

View File

@@ -1,5 +1,5 @@
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{debug, info};
use tracing::{info, warn};
use super::{ComputeCredentials, ComputeUserInfo, ComputeUserInfoNoEndpoint};
use crate::auth::{self, AuthFlow};
@@ -21,7 +21,7 @@ pub(crate) async fn authenticate_cleartext(
secret: AuthSecret,
config: &'static AuthenticationConfig,
) -> auth::Result<ComputeCredentials> {
debug!("cleartext auth flow override is enabled, proceeding");
warn!("cleartext auth flow override is enabled, proceeding");
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
// pause the timer while we communicate with the client
@@ -61,7 +61,7 @@ pub(crate) async fn password_hack_no_authentication(
info: ComputeUserInfoNoEndpoint,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
) -> auth::Result<(ComputeUserInfo, Vec<u8>)> {
debug!("project not specified, resorting to the password hack auth flow");
warn!("project not specified, resorting to the password hack auth flow");
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
// pause the timer while we communicate with the client

View File

@@ -137,6 +137,9 @@ struct ProxyCliArgs {
/// size of the threadpool for password hashing
#[clap(long, default_value_t = 4)]
scram_thread_pool_size: u8,
/// Disable the dynamic rate limiter while still recording its metrics, to evaluate its would-be production behaviour.
#[clap(long, default_value_t = true, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
disable_dynamic_rate_limiter: bool,
/// Endpoint rate limiter max number of requests per second.
///
/// Provided in the form `<Requests Per Second>@<Bucket Duration Size>`.
@@ -612,6 +615,9 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
and metric-collection-interval must be specified"
),
};
if !args.disable_dynamic_rate_limiter {
bail!("dynamic rate limiter should be disabled");
}
let config::ConcurrencyLockOptions {
shards,

View File

@@ -42,6 +42,9 @@ pytest-repeat = "^0.9.3"
websockets = "^12.0"
clickhouse-connect = "^0.7.16"
kafka-python = "^2.0.2"
jwcrypto = "^1.5.6"
h2 = "^4.1.0"
types-jwcrypto = "^1.5.0.20240925"
[tool.poetry.group.dev.dependencies]
mypy = "==1.3.0"

View File

@@ -21,15 +21,18 @@ use postgres_backend::QueryError;
use pq_proto::BeMessage;
use serde::Deserialize;
use serde::Serialize;
use std::future;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::io::AsyncRead;
use tokio::io::AsyncWrite;
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::sync::mpsc::channel;
use tokio::sync::mpsc::error::TryRecvError;
use tokio::sync::mpsc::Receiver;
use tokio::sync::mpsc::Sender;
use tokio::task;
use tokio::task::JoinHandle;
use tokio::time::{Duration, MissedTickBehavior};
use tokio::time::Duration;
use tokio::time::Instant;
use tracing::*;
use utils::id::TenantTimelineId;
use utils::lsn::Lsn;
@@ -441,9 +444,9 @@ async fn network_write<IO: AsyncRead + AsyncWrite + Unpin>(
}
}
/// The WAL flush interval. This ensures we periodically flush the WAL and send AppendResponses to
/// walproposer, even when it's writing a steady stream of messages.
const FLUSH_INTERVAL: Duration = Duration::from_secs(1);
// Send keepalive messages to walproposer, to make sure it receives updates
// even when it writes a steady stream of messages.
const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1);
/// Encapsulates a task which takes messages from msg_rx, processes and pushes
/// replies to reply_tx.
@@ -491,76 +494,67 @@ impl WalAcceptor {
async fn run(&mut self) -> anyhow::Result<()> {
let walreceiver_guard = self.tli.get_walreceivers().register(self.conn_id);
// Periodically flush the WAL.
let mut flush_ticker = tokio::time::interval(FLUSH_INTERVAL);
flush_ticker.set_missed_tick_behavior(MissedTickBehavior::Delay);
flush_ticker.tick().await; // skip the initial, immediate tick
// After this timestamp we will stop processing AppendRequests and send a response
// to the walproposer. walproposer sends at least one AppendRequest per second, so
// we will send keepalives by replying to these requests once per second.
let mut next_keepalive = Instant::now();
// Tracks unflushed appends.
let mut dirty = false;
while let Some(mut next_msg) = self.msg_rx.recv().await {
// Update walreceiver state in shmem for reporting.
if let ProposerAcceptorMessage::Elected(_) = &next_msg {
walreceiver_guard.get().status = WalReceiverStatus::Streaming;
}
loop {
let reply = tokio::select! {
// Process inbound message.
msg = self.msg_rx.recv() => {
// If disconnected, break to flush WAL and return.
let Some(mut msg) = msg else {
break;
};
// Update walreceiver state in shmem for reporting.
if let ProposerAcceptorMessage::Elected(_) = &msg {
walreceiver_guard.get().status = WalReceiverStatus::Streaming;
}
// Don't flush the WAL on every append, only periodically via flush_ticker.
// This batches multiple appends per fsync. If the channel is empty after
// sending the reply, we'll schedule an immediate flush.
if let ProposerAcceptorMessage::AppendRequest(append_request) = msg {
msg = ProposerAcceptorMessage::NoFlushAppendRequest(append_request);
dirty = true;
}
self.tli.process_msg(&msg).await?
}
// While receiving AppendRequests, flush the WAL periodically and respond with an
// AppendResponse to let walproposer know we're still alive.
_ = flush_ticker.tick(), if dirty => {
dirty = false;
self.tli
.process_msg(&ProposerAcceptorMessage::FlushWAL)
.await?
}
// If there are no pending messages, flush the WAL immediately.
let reply_msg = if matches!(next_msg, ProposerAcceptorMessage::AppendRequest(_)) {
// Loop through AppendRequests while available to write as many WAL records as
// possible without fsyncing.
//
// TODO: this should be done via flush_ticker.reset_immediately(), but that's always
// delayed by 1ms due to this bug: https://github.com/tokio-rs/tokio/issues/6866.
_ = future::ready(()), if dirty && self.msg_rx.is_empty() => {
dirty = false;
flush_ticker.reset();
self.tli
.process_msg(&ProposerAcceptorMessage::FlushWAL)
.await?
// Make sure the WAL is flushed before returning, see:
// https://github.com/neondatabase/neon/issues/9259
//
// Note: this will need to be rewritten if we want to read non-AppendRequest messages here.
// Otherwise, we might end up in a situation where we read a message, but don't
// process it.
while let ProposerAcceptorMessage::AppendRequest(append_request) = next_msg {
let noflush_msg = ProposerAcceptorMessage::NoFlushAppendRequest(append_request);
if let Some(reply) = self.tli.process_msg(&noflush_msg).await? {
if self.reply_tx.send(reply).await.is_err() {
break; // disconnected, flush WAL and return on next send/recv
}
}
// get out of this loop if keepalive time is reached
if Instant::now() >= next_keepalive {
break;
}
// continue pulling AppendRequests if available
match self.msg_rx.try_recv() {
Ok(msg) => next_msg = msg,
Err(TryRecvError::Empty) => break,
// on disconnect, flush WAL and return on next send/recv
Err(TryRecvError::Disconnected) => break,
};
}
// flush all written WAL to the disk
self.tli
.process_msg(&ProposerAcceptorMessage::FlushWAL)
.await?
} else {
// process message other than AppendRequest
self.tli.process_msg(&next_msg).await?
};
// Send reply, if any.
if let Some(reply) = reply {
if let Some(reply) = reply_msg {
if self.reply_tx.send(reply).await.is_err() {
break; // disconnected, break to flush WAL and return
return Ok(()); // chan closed, streaming terminated
}
// reset keepalive time
next_keepalive = Instant::now() + KEEPALIVE_INTERVAL;
}
}
// Flush WAL on disconnect, see https://github.com/neondatabase/neon/issues/9259.
if dirty {
self.tli
.process_msg(&ProposerAcceptorMessage::FlushWAL)
.await?;
}
Ok(())
}
}
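
For reference, the batching pattern above condenses to: process queued appends without fsyncing, then flush once and reply when the queue drains or the keepalive deadline passes. Below is a minimal sketch of that loop in Python, with hypothetical process/flush/reply callbacks standing in for process_msg, FlushWAL, and reply_tx; the real code also handles non-append messages and flushes on disconnect.

import asyncio
import time

KEEPALIVE_INTERVAL = 1.0  # seconds, mirroring the Rust constant above

async def acceptor_loop(msg_rx: asyncio.Queue, process, flush, reply):
    # Batch WAL appends between fsyncs: drain whatever is already queued
    # without flushing, then do one flush and one reply so the sender sees
    # progress at least once per KEEPALIVE_INTERVAL.
    next_keepalive = time.monotonic()
    while True:
        msg = await msg_rx.get()  # wait for the first message of a batch
        while True:
            process(msg)  # write the WAL record, no fsync yet
            if time.monotonic() >= next_keepalive:
                break  # keepalive deadline reached: stop draining
            try:
                msg = msg_rx.get_nowait()  # keep pulling the backlog
            except asyncio.QueueEmpty:
                break  # nothing pending: flush immediately
        flush()  # a single fsync covers the whole batch
        reply()  # the AppendResponse doubles as the keepalive
        next_keepalive = time.monotonic() + KEEPALIVE_INTERVAL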

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
pytest_plugins = (
"fixtures.pg_version",
"fixtures.parametrize",
"fixtures.h2server",
"fixtures.httpserver",
"fixtures.compute_reconfigure",
"fixtures.storage_controller_proxy",

View File

@@ -0,0 +1,198 @@
"""
https://python-hyper.org/projects/hyper-h2/en/stable/asyncio-example.html
auth-broker -> local-proxy needs an h2 connection, so we need an h2 server :)
"""
import asyncio
import collections
import io
import json
from collections.abc import AsyncIterable
import pytest_asyncio
from h2.config import H2Configuration
from h2.connection import H2Connection
from h2.errors import ErrorCodes
from h2.events import (
ConnectionTerminated,
DataReceived,
RemoteSettingsChanged,
RequestReceived,
StreamEnded,
StreamReset,
WindowUpdated,
)
from h2.exceptions import ProtocolError, StreamClosedError
from h2.settings import SettingCodes
RequestData = collections.namedtuple("RequestData", ["headers", "data"])
class H2Server:
def __init__(self, host, port) -> None:
self.host = host
self.port = port
class H2Protocol(asyncio.Protocol):
def __init__(self):
config = H2Configuration(client_side=False, header_encoding="utf-8")
self.conn = H2Connection(config=config)
self.transport = None
self.stream_data = {}
self.flow_control_futures = {}
def connection_made(self, transport: asyncio.Transport): # type: ignore[override]
self.transport = transport
self.conn.initiate_connection()
self.transport.write(self.conn.data_to_send())
def connection_lost(self, _exc):
for future in self.flow_control_futures.values():
future.cancel()
self.flow_control_futures = {}
def data_received(self, data: bytes):
assert self.transport is not None
try:
events = self.conn.receive_data(data)
except ProtocolError:
self.transport.write(self.conn.data_to_send())
self.transport.close()
else:
self.transport.write(self.conn.data_to_send())
for event in events:
if isinstance(event, RequestReceived):
self.request_received(event.headers, event.stream_id)
elif isinstance(event, DataReceived):
self.receive_data(event.data, event.stream_id)
elif isinstance(event, StreamEnded):
self.stream_complete(event.stream_id)
elif isinstance(event, ConnectionTerminated):
self.transport.close()
elif isinstance(event, StreamReset):
self.stream_reset(event.stream_id)
elif isinstance(event, WindowUpdated):
self.window_updated(event.stream_id, event.delta)
elif isinstance(event, RemoteSettingsChanged):
if SettingCodes.INITIAL_WINDOW_SIZE in event.changed_settings:
self.window_updated(None, 0)
self.transport.write(self.conn.data_to_send())
def request_received(self, headers: list[tuple[str, str]], stream_id: int):
headers_map = collections.OrderedDict(headers)
# Store off the request data.
request_data = RequestData(headers_map, io.BytesIO())
self.stream_data[stream_id] = request_data
def stream_complete(self, stream_id: int):
"""
When a stream is complete, we can send our response.
"""
try:
request_data = self.stream_data[stream_id]
except KeyError:
# Just return, we probably 405'd this already
return
headers = request_data.headers
body = request_data.data.getvalue().decode("utf-8")
data = json.dumps({"headers": headers, "body": body}, indent=4).encode("utf8")
response_headers = (
(":status", "200"),
("content-type", "application/json"),
("content-length", str(len(data))),
)
self.conn.send_headers(stream_id, response_headers)
asyncio.ensure_future(self.send_data(data, stream_id))
def receive_data(self, data: bytes, stream_id: int):
"""
We've received some data on a stream. If that stream is one we're
expecting data on, save it off. Otherwise, reset the stream.
"""
try:
stream_data = self.stream_data[stream_id]
except KeyError:
self.conn.reset_stream(stream_id, error_code=ErrorCodes.PROTOCOL_ERROR)
else:
stream_data.data.write(data)
def stream_reset(self, stream_id):
"""
A stream reset was sent. Stop sending data.
"""
if stream_id in self.flow_control_futures:
future = self.flow_control_futures.pop(stream_id)
future.cancel()
async def send_data(self, data, stream_id):
"""
Send data according to the flow control rules.
"""
while data:
while self.conn.local_flow_control_window(stream_id) < 1:
try:
await self.wait_for_flow_control(stream_id)
except asyncio.CancelledError:
return
chunk_size = min(
self.conn.local_flow_control_window(stream_id),
len(data),
self.conn.max_outbound_frame_size,
)
try:
self.conn.send_data(
stream_id, data[:chunk_size], end_stream=(chunk_size == len(data))
)
except (StreamClosedError, ProtocolError):
# The stream got closed and we didn't get told. We're done
# here.
break
assert self.transport is not None
self.transport.write(self.conn.data_to_send())
data = data[chunk_size:]
async def wait_for_flow_control(self, stream_id):
"""
Waits for a Future that fires when the flow control window is opened.
"""
f: asyncio.Future[None] = asyncio.Future()
self.flow_control_futures[stream_id] = f
await f
def window_updated(self, stream_id, delta):
"""
A window update frame was received. Unblock some number of flow control
Futures.
"""
if stream_id and stream_id in self.flow_control_futures:
f = self.flow_control_futures.pop(stream_id)
f.set_result(delta)
elif not stream_id:
for f in self.flow_control_futures.values():
f.set_result(delta)
self.flow_control_futures = {}
@pytest_asyncio.fixture(scope="function")
async def http2_echoserver() -> AsyncIterable[H2Server]:
loop = asyncio.get_event_loop()
serve = await loop.create_server(H2Protocol, "127.0.0.1", 0)
(host, port) = serve.sockets[0].getsockname()
asyncio.create_task(serve.wait_closed())
server = H2Server(host, port)
yield server
serve.close()
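
For reference, this fixture speaks cleartext HTTP/2 with prior knowledge (no TLS and no Upgrade dance), so a client must open TCP and send the h2 preface itself. A minimal client sketch using the same h2 package follows; h2_post is a hypothetical helper for illustration, not part of the test suite.

import asyncio
from h2.config import H2Configuration
from h2.connection import H2Connection
from h2.events import DataReceived, StreamEnded

async def h2_post(host: str, port: int, body: bytes) -> bytes:
    reader, writer = await asyncio.open_connection(host, port)
    conn = H2Connection(H2Configuration(client_side=True, header_encoding="utf-8"))
    conn.initiate_connection()  # queues the h2 connection preface
    stream_id = conn.get_next_available_stream_id()
    conn.send_headers(stream_id, [
        (":method", "POST"),
        (":path", "/sql"),
        (":scheme", "http"),
        (":authority", host),
    ])
    conn.send_data(stream_id, body, end_stream=True)
    writer.write(conn.data_to_send())
    await writer.drain()
    chunks: list[bytes] = []
    done = False
    while not done:
        data = await reader.read(65536)
        if not data:
            break
        for event in conn.receive_data(data):
            if isinstance(event, DataReceived):
                chunks.append(event.data)
                # Reopen the flow-control window for further frames.
                conn.acknowledge_received_data(event.flow_controlled_length, event.stream_id)
            elif isinstance(event, StreamEnded):
                done = True
        writer.write(conn.data_to_send())
    writer.close()
    return b"".join(chunks)

Against the echo fixture, asyncio.run(h2_post(server.host, server.port, b"{}")) returns the JSON echo of the request headers and body.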

View File

@@ -35,6 +35,7 @@ import toml
from _pytest.config import Config
from _pytest.config.argparsing import Parser
from _pytest.fixtures import FixtureRequest
from jwcrypto import jwk
# Type-related stuff
from psycopg2.extensions import connection as PgConnection
@@ -54,6 +55,7 @@ from fixtures.common_types import (
TimelineId,
)
from fixtures.endpoint.http import EndpointHttpClient
from fixtures.h2server import H2Server
from fixtures.log_helper import log
from fixtures.metrics import Metrics, MetricsGetter, parse_metrics
from fixtures.neon_cli import NeonLocalCli, Pagectl
@@ -3093,12 +3095,41 @@ class PSQL:
)
def generate_proxy_tls_certs(common_name: str, key_path: Path, crt_path: Path):
if not key_path.exists():
r = subprocess.run(
[
"openssl",
"req",
"-new",
"-x509",
"-days",
"365",
"-nodes",
"-text",
"-out",
str(crt_path),
"-keyout",
str(key_path),
"-subj",
f"/CN={common_name}",
"-addext",
f"subjectAltName = DNS:{common_name}",
]
)
assert r.returncode == 0
class NeonProxy(PgProtocol):
link_auth_uri: str = "http://dummy-uri"
class AuthBackend(abc.ABC):
"""All auth backends must inherit from this class"""
@property
def default_conn_url(self) -> Optional[str]:
return None
@abc.abstractmethod
def extra_args(self) -> list[str]:
pass
@@ -3112,7 +3143,7 @@ class NeonProxy(PgProtocol):
*["--allow-self-signed-compute", "true"],
]
class ControlPlane(AuthBackend):
class Console(AuthBackend):
def __init__(self, endpoint: str, fixed_rate_limit: Optional[int] = None):
self.endpoint = endpoint
self.fixed_rate_limit = fixed_rate_limit
@@ -3136,6 +3167,21 @@ class NeonProxy(PgProtocol):
]
return args
@dataclass(frozen=True)
class Postgres(AuthBackend):
pg_conn_url: str
@property
def default_conn_url(self) -> Optional[str]:
return self.pg_conn_url
def extra_args(self) -> list[str]:
return [
# Postgres auth backend params
*["--auth-backend", "postgres"],
*["--auth-endpoint", self.pg_conn_url],
]
def __init__(
self,
neon_binpath: Path,
@@ -3150,7 +3196,7 @@ class NeonProxy(PgProtocol):
):
host = "127.0.0.1"
domain = "proxy.localtest.me" # resolves to 127.0.0.1
super().__init__(host=domain, port=proxy_port)
super().__init__(dsn=auth_backend.default_conn_url, host=domain, port=proxy_port)
self.domain = domain
self.host = host
@@ -3172,29 +3218,7 @@ class NeonProxy(PgProtocol):
# generate key if it doesn't exist
crt_path = self.test_output_dir / "proxy.crt"
key_path = self.test_output_dir / "proxy.key"
if not key_path.exists():
r = subprocess.run(
[
"openssl",
"req",
"-new",
"-x509",
"-days",
"365",
"-nodes",
"-text",
"-out",
str(crt_path),
"-keyout",
str(key_path),
"-subj",
"/CN=*.localtest.me",
"-addext",
"subjectAltName = DNS:*.localtest.me",
]
)
assert r.returncode == 0
generate_proxy_tls_certs("*.localtest.me", key_path, crt_path)
args = [
str(self.neon_binpath / "proxy"),
@@ -3374,6 +3398,125 @@ class NeonProxy(PgProtocol):
assert out == "ok"
class NeonAuthBroker:
class ControlPlane:
def __init__(self, endpoint: str):
self.endpoint = endpoint
def extra_args(self) -> list[str]:
args = [
*["--auth-backend", "console"],
*["--auth-endpoint", self.endpoint],
]
return args
def __init__(
self,
neon_binpath: Path,
test_output_dir: Path,
http_port: int,
mgmt_port: int,
external_http_port: int,
auth_backend: NeonAuthBroker.ControlPlane,
):
self.domain = "apiauth.localtest.me" # resolves to 127.0.0.1
self.host = "127.0.0.1"
self.http_port = http_port
self.external_http_port = external_http_port
self.neon_binpath = neon_binpath
self.test_output_dir = test_output_dir
self.mgmt_port = mgmt_port
self.auth_backend = auth_backend
self.http_timeout_seconds = 15
self._popen: Optional[subprocess.Popen[bytes]] = None
def start(self) -> NeonAuthBroker:
assert self._popen is None
# generate key if it doesn't exist
crt_path = self.test_output_dir / "proxy.crt"
key_path = self.test_output_dir / "proxy.key"
generate_proxy_tls_certs("apiauth.localtest.me", key_path, crt_path)
args = [
str(self.neon_binpath / "proxy"),
*["--http", f"{self.host}:{self.http_port}"],
*["--mgmt", f"{self.host}:{self.mgmt_port}"],
*["--wss", f"{self.host}:{self.external_http_port}"],
*["-c", str(crt_path)],
*["-k", str(key_path)],
*["--sql-over-http-pool-opt-in", "false"],
*["--is-auth-broker", "true"],
*self.auth_backend.extra_args(),
]
logfile = open(self.test_output_dir / "proxy.log", "w")
self._popen = subprocess.Popen(args, stdout=logfile, stderr=logfile)
self._wait_until_ready()
return self
# Sends SIGTERM to the proxy if it has been started
def terminate(self):
if self._popen:
self._popen.terminate()
# Waits for the proxy to exit if it has been started, with a default timeout of
# two seconds. Raises subprocess.TimeoutExpired if the proxy does not exit in time.
def wait_for_exit(self, timeout=2):
if self._popen:
self._popen.wait(timeout=timeout)
@backoff.on_exception(backoff.expo, requests.exceptions.RequestException, max_time=10)
def _wait_until_ready(self):
assert (
self._popen and self._popen.poll() is None
), "Proxy exited unexpectedly. Check test log."
requests.get(f"http://{self.host}:{self.http_port}/v1/status")
async def query(self, query, args, **kwargs):
user = kwargs["user"]
token = kwargs["token"]
expected_code = kwargs.get("expected_code")
log.info(f"Executing http query: {query}")
connstr = f"postgresql://{user}@{self.domain}/postgres"
async with httpx.AsyncClient(verify=str(self.test_output_dir / "proxy.crt")) as client:
response = await client.post(
f"https://{self.domain}:{self.external_http_port}/sql",
json={"query": query, "params": args},
headers={
"Neon-Connection-String": connstr,
"Authorization": f"Bearer {token}",
},
)
if expected_code is not None:
assert response.status_code == expected_code, f"response: {response.json()}"
return response.json()
def get_metrics(self) -> str:
request_result = requests.get(f"http://{self.host}:{self.http_port}/metrics")
return request_result.text
def __enter__(self) -> NeonAuthBroker:
return self
def __exit__(
self,
_exc_type: Optional[type[BaseException]],
_exc_value: Optional[BaseException],
_traceback: Optional[TracebackType],
):
if self._popen is not None:
self._popen.terminate()
try:
self._popen.wait(timeout=5)
except subprocess.TimeoutExpired:
log.warning("failed to gracefully terminate proxy; killing")
self._popen.kill()
@pytest.fixture(scope="function")
def link_proxy(
port_distributor: PortDistributor, neon_binpath: Path, test_output_dir: Path
@@ -3404,39 +3547,20 @@ def static_proxy(
port_distributor: PortDistributor,
neon_binpath: Path,
test_output_dir: Path,
httpserver: HTTPServer,
) -> Iterator[NeonProxy]:
"""Neon proxy that routes directly to vanilla postgres and a mocked cplane HTTP API."""
"""Neon proxy that routes directly to vanilla postgres."""
port = vanilla_pg.default_options["port"]
host = vanilla_pg.default_options["host"]
dbname = vanilla_pg.default_options["dbname"]
auth_endpoint = f"postgres://proxy:password@{host}:{port}/{dbname}"
# For simplicity, we use the same user for both `--auth-endpoint` and `safe_psql`
vanilla_pg.start()
vanilla_pg.safe_psql("create user proxy with login superuser password 'password'")
[(rolpassword,)] = vanilla_pg.safe_psql(
"select rolpassword from pg_catalog.pg_authid where rolname = 'proxy'"
)
# return local postgres addr on ProxyWakeCompute.
httpserver.expect_request("/cplane/proxy_wake_compute").respond_with_json(
{
"address": f"{host}:{port}",
"aux": {
"endpoint_id": "ep-foo-bar-1234",
"branch_id": "br-foo-bar",
"project_id": "foo-bar",
},
}
)
# return local postgres addr on ProxyWakeCompute.
httpserver.expect_request("/cplane/proxy_get_role_secret").respond_with_json(
{
"role_secret": rolpassword,
"allowed_ips": None,
"project_id": "foo-bar",
}
vanilla_pg.safe_psql("CREATE SCHEMA IF NOT EXISTS neon_control_plane")
vanilla_pg.safe_psql(
"CREATE TABLE neon_control_plane.endpoints (endpoint_id VARCHAR(255) PRIMARY KEY, allowed_ips VARCHAR(255))"
)
proxy_port = port_distributor.get_port()
@@ -3451,12 +3575,92 @@ def static_proxy(
http_port=http_port,
mgmt_port=mgmt_port,
external_http_port=external_http_port,
auth_backend=NeonProxy.ControlPlane(httpserver.url_for("/cplane")),
auth_backend=NeonProxy.Postgres(auth_endpoint),
) as proxy:
proxy.default_options["user"] = "proxy"
proxy.default_options["password"] = "password"
proxy.default_options["dbname"] = dbname
proxy.start()
yield proxy
@pytest.fixture(scope="function")
def neon_authorize_jwk() -> jwk.JWK:
kid = str(uuid.uuid4())
key = jwk.JWK.generate(kty="RSA", size=2048, alg="RS256", use="sig", kid=kid)
assert isinstance(key, jwk.JWK)
return key
@pytest.fixture(scope="function")
def cplane_endpoint_jwks(
httpserver: HTTPServer,
neon_authorize_jwk: jwk.JWK,
role_names: list[str],
audience: str | None,
) -> jwk.JWK:
# return static fixture jwks.
jwk = neon_authorize_jwk.export_public(as_dict=True)
httpserver.expect_request("/authorize/jwks.json").respond_with_json({"keys": [jwk]})
id = str(uuid.uuid4())
# return jwks mock addr on GetEndpointJwks
httpserver.expect_request(re.compile("^/cplane/endpoints/.+/jwks$")).respond_with_json(
{
"jwks": [
{
"id": id,
"jwks_url": httpserver.url_for("/authorize/jwks.json"),
"provider_name": "test",
"jwt_audience": audience,
"role_names": role_names,
}
]
}
)
return neon_authorize_jwk
@pytest.fixture(scope="function")
def cplane_wake_compute_local_proxy(
httpserver: HTTPServer,
http2_echoserver: H2Server,
) -> None:
local_proxy_addr = f"{http2_echoserver.host}:{http2_echoserver.port}"
# return local_proxy addr on ProxyWakeCompute.
httpserver.expect_request("/cplane/proxy_wake_compute").respond_with_json(
{
"address": local_proxy_addr,
"aux": {
"endpoint_id": "ep-foo-bar-1234",
"branch_id": "br-foo-bar",
"project_id": "foo-bar",
},
}
)
@pytest.fixture(scope="function")
def static_auth_broker(
port_distributor: PortDistributor,
neon_binpath: Path,
test_output_dir: Path,
httpserver: HTTPServer,
) -> Iterable[NeonAuthBroker]:
"""Neon Auth Broker that routes to a mocked local_proxy and a mocked cplane HTTP API."""
mgmt_port = port_distributor.get_port()
http_port = port_distributor.get_port()
external_http_port = port_distributor.get_port()
with NeonAuthBroker(
neon_binpath=neon_binpath,
test_output_dir=test_output_dir,
http_port=http_port,
mgmt_port=mgmt_port,
external_http_port=external_http_port,
auth_backend=NeonAuthBroker.ControlPlane(httpserver.url_for("/cplane")),
) as proxy:
proxy.start()
yield proxy
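
For reference, the JWT round trip these fixtures arrange works end to end like this: the private key signs the token, the mocked JWKS endpoint serves only the public half, and the broker verifies the signature against it. A minimal jwcrypto sketch under those assumptions (the broker's real validation additionally checks role_names and jwt_audience):

from jwcrypto import jwk, jwt

# Generate the signing key, as the neon_authorize_jwk fixture does.
key = jwk.JWK.generate(kty="RSA", size=2048, alg="RS256", use="sig", kid="demo-kid")

# What the mocked /authorize/jwks.json endpoint serves: the public half only.
public_jwks = {"keys": [key.export_public(as_dict=True)]}

# What the test client does: sign a token with the private key.
token = jwt.JWT(header={"kid": key.key_id, "alg": "RS256"}, claims={"sub": "user1"})
token.make_signed_token(key)
serialized = token.serialize()

# Approximately what the broker does: look the kid up in the fetched JWKS
# and verify the signature before forwarding the request to local_proxy.
verifier = jwk.JWK(**public_jwks["keys"][0])
checked = jwt.JWT(jwt=serialized, key=verifier)  # raises on a bad signature
print(checked.claims)  # '{"sub": "user1"}'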

View File

@@ -0,0 +1,114 @@
import json
import pytest
from fixtures.neon_fixtures import NeonAuthBroker
from jwcrypto import jwk, jwt
@pytest.mark.asyncio
@pytest.mark.parametrize("role_names", [["anonymous", "authenticated"]])
@pytest.mark.parametrize("audience", [None])
async def test_auth_broker_happy(
static_auth_broker: NeonAuthBroker,
cplane_wake_compute_local_proxy: None,
cplane_endpoint_jwks: jwk.JWK,
):
"""
Signs a JWT and uses it to authorize a query to local_proxy.
"""
token = jwt.JWT(
header={"kid": cplane_endpoint_jwks.key_id, "alg": "RS256"}, claims={"sub": "user1"}
)
token.make_signed_token(cplane_endpoint_jwks)
res = await static_auth_broker.query(
"foo", ["arg1"], user="anonymous", token=token.serialize(), expected_code=200
)
# local proxy mock just echoes back the request
# check that we forward the correct data
assert (
res["headers"]["authorization"] == f"Bearer {token.serialize()}"
), "JWT should be forwarded"
assert (
"anonymous" in res["headers"]["neon-connection-string"]
), "conn string should be forwarded"
assert json.loads(res["body"]) == {
"query": "foo",
"params": ["arg1"],
}, "Query body should be forwarded"
@pytest.mark.asyncio
@pytest.mark.parametrize("role_names", [["anonymous", "authenticated"]])
@pytest.mark.parametrize("audience", [None])
async def test_auth_broker_incorrect_role(
static_auth_broker: NeonAuthBroker,
cplane_wake_compute_local_proxy: None,
cplane_endpoint_jwks: jwk.JWK,
):
"""
Connects to auth broker with the wrong role associated with the JWKs
"""
token = jwt.JWT(
header={"kid": cplane_endpoint_jwks.key_id, "alg": "RS256"}, claims={"sub": "user1"}
)
token.make_signed_token(cplane_endpoint_jwks)
res = await static_auth_broker.query(
"foo", ["arg1"], user="wrong_role", token=token.serialize(), expected_code=400
)
# if the user is wrong, we announce that the jwk was not found.
assert "jwk not found" in res["message"]
@pytest.mark.asyncio
@pytest.mark.parametrize("role_names", [["anonymous", "authenticated"]])
@pytest.mark.parametrize("audience", ["neon"])
async def test_auth_broker_incorrect_aud(
static_auth_broker: NeonAuthBroker,
cplane_wake_compute_local_proxy: None,
cplane_endpoint_jwks: jwk.JWK,
):
"""
Connects to auth broker with the wrong audience associated with the JWKs
"""
token = jwt.JWT(
header={"kid": cplane_endpoint_jwks.key_id, "alg": "RS256"},
claims={"sub": "user1", "aud": "wrong_aud"},
)
token.make_signed_token(cplane_endpoint_jwks)
res = await static_auth_broker.query(
"foo", ["arg1"], user="anonymous", token=token.serialize(), expected_code=400
)
assert "invalid JWT token audience" in res["message"]
@pytest.mark.asyncio
@pytest.mark.parametrize("role_names", [["anonymous", "authenticated"]])
@pytest.mark.parametrize("audience", ["neon"])
async def test_auth_broker_missing_aud(
static_auth_broker: NeonAuthBroker,
cplane_wake_compute_local_proxy: None,
cplane_endpoint_jwks: jwk.JWK,
):
"""
Connects to auth broker with no audience
"""
token = jwt.JWT(
header={"kid": cplane_endpoint_jwks.key_id, "alg": "RS256"},
claims={"sub": "user1"},
)
token.make_signed_token(cplane_endpoint_jwks)
res = await static_auth_broker.query(
"foo", ["arg1"], user="anonymous", token=token.serialize(), expected_code=400
)
assert "invalid JWT token audience" in res["message"]

View File

@@ -6,27 +6,20 @@ from fixtures.neon_fixtures import (
NeonProxy,
VanillaPostgres,
)
from pytest_httpserver import HTTPServer
TABLE_NAME = "neon_control_plane.endpoints"
def test_proxy_psql_not_allowed_ips(
static_proxy: NeonProxy,
vanilla_pg: VanillaPostgres,
httpserver: HTTPServer,
):
[(rolpassword,)] = vanilla_pg.safe_psql(
"select rolpassword from pg_catalog.pg_authid where rolname = 'proxy'"
)
# Proxy uses the same logic for psql and websockets.
@pytest.mark.asyncio
async def test_proxy_psql_allowed_ips(static_proxy: NeonProxy, vanilla_pg: VanillaPostgres):
# Shouldn't be able to connect to this project
httpserver.expect_request("/cplane/proxy_get_role_secret").respond_with_json(
{
"role_secret": rolpassword,
"allowed_ips": ["8.8.8.8"],
"project_id": "foo-bar",
}
vanilla_pg.safe_psql(
f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('private-project', '8.8.8.8')"
)
# Should be able to connect to this project
vanilla_pg.safe_psql(
f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('generic-project', '::1,127.0.0.1')"
)
def check_cannot_connect(**kwargs):
@@ -44,25 +37,6 @@ def test_proxy_psql_not_allowed_ips(
# with SNI
check_cannot_connect(query="select 1", host="private-project.localtest.me")
def test_proxy_psql_allowed_ips(
static_proxy: NeonProxy,
vanilla_pg: VanillaPostgres,
httpserver: HTTPServer,
):
[(rolpassword,)] = vanilla_pg.safe_psql(
"select rolpassword from pg_catalog.pg_authid where rolname = 'proxy'"
)
# Should be able to connect to this project
httpserver.expect_request("/cplane/proxy_get_role_secret").respond_with_json(
{
"role_secret": rolpassword,
"allowed_ips": ["::1", "127.0.0.1"],
"project_id": "foo-bar",
}
)
# no SNI, deprecated `options=project` syntax (before a project could have several endpoints)
out = static_proxy.safe_psql(query="select 1", sslsni=0, options="project=generic-project")
assert out[0][0] == 1
@@ -76,61 +50,27 @@ def test_proxy_psql_allowed_ips(
assert out[0][0] == 1
def test_proxy_http_not_allowed_ips(
static_proxy: NeonProxy,
vanilla_pg: VanillaPostgres,
httpserver: HTTPServer,
):
vanilla_pg.safe_psql("create user http_auth with password 'http' superuser")
@pytest.mark.asyncio
async def test_proxy_http_allowed_ips(static_proxy: NeonProxy, vanilla_pg: VanillaPostgres):
static_proxy.safe_psql("create user http_auth with password 'http' superuser")
[(rolpassword,)] = vanilla_pg.safe_psql(
"select rolpassword from pg_catalog.pg_authid where rolname = 'http_auth'"
# Shouldn't be able to connect to this project
vanilla_pg.safe_psql(
f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('proxy', '8.8.8.8')"
)
httpserver.expect_oneshot_request("/cplane/proxy_get_role_secret").respond_with_json(
{
"role_secret": rolpassword,
"allowed_ips": ["8.8.8.8"],
"project_id": "foo-bar",
}
)
with httpserver.wait() as waiting:
def query(status: int, query: str, *args):
static_proxy.http_query(
"select 1;",
[],
query,
args,
user="http_auth",
password="http",
expected_code=400,
expected_code=status,
)
assert waiting.result
def test_proxy_http_allowed_ips(
static_proxy: NeonProxy,
vanilla_pg: VanillaPostgres,
httpserver: HTTPServer,
):
vanilla_pg.safe_psql("create user http_auth with password 'http' superuser")
[(rolpassword,)] = vanilla_pg.safe_psql(
"select rolpassword from pg_catalog.pg_authid where rolname = 'http_auth'"
query(400, "select 1;") # ip address is not allowed
# Should be able to connect to this project
vanilla_pg.safe_psql(
f"UPDATE {TABLE_NAME} SET allowed_ips = '8.8.8.8,127.0.0.1' WHERE endpoint_id = 'proxy'"
)
httpserver.expect_oneshot_request("/cplane/proxy_get_role_secret").respond_with_json(
{
"role_secret": rolpassword,
"allowed_ips": ["8.8.8.8", "127.0.0.1"],
"project_id": "foo-bar",
}
)
with httpserver.wait() as waiting:
static_proxy.http_query(
"select 1;",
[],
user="http_auth",
password="http",
expected_code=200,
)
assert waiting.result
query(200, "select 1;") # should work now

View File

@@ -0,0 +1 @@
generated via `poetry run stubgen -p h2 -o test_runner/stubs`

View File

View File

@@ -0,0 +1,42 @@
from _typeshed import Incomplete
class _BooleanConfigOption:
name: Incomplete
attr_name: Incomplete
def __init__(self, name) -> None: ...
def __get__(self, instance, owner): ...
def __set__(self, instance, value) -> None: ...
class DummyLogger:
def __init__(self, *vargs) -> None: ...
def debug(self, *vargs, **kwargs) -> None: ...
def trace(self, *vargs, **kwargs) -> None: ...
class OutputLogger:
file: Incomplete
trace_level: Incomplete
def __init__(self, file: Incomplete | None = ..., trace_level: bool = ...) -> None: ...
def debug(self, fmtstr, *args) -> None: ...
def trace(self, fmtstr, *args) -> None: ...
class H2Configuration:
client_side: Incomplete
validate_outbound_headers: Incomplete
normalize_outbound_headers: Incomplete
validate_inbound_headers: Incomplete
normalize_inbound_headers: Incomplete
logger: Incomplete
def __init__(
self,
client_side: bool = ...,
header_encoding: Incomplete | None = ...,
validate_outbound_headers: bool = ...,
normalize_outbound_headers: bool = ...,
validate_inbound_headers: bool = ...,
normalize_inbound_headers: bool = ...,
logger: Incomplete | None = ...,
) -> None: ...
@property
def header_encoding(self): ...
@header_encoding.setter
def header_encoding(self, value) -> None: ...

View File

@@ -0,0 +1,142 @@
from enum import Enum, IntEnum
from _typeshed import Incomplete
from .config import H2Configuration as H2Configuration
from .errors import ErrorCodes as ErrorCodes
from .events import AlternativeServiceAvailable as AlternativeServiceAvailable
from .events import ConnectionTerminated as ConnectionTerminated
from .events import PingAckReceived as PingAckReceived
from .events import PingReceived as PingReceived
from .events import PriorityUpdated as PriorityUpdated
from .events import RemoteSettingsChanged as RemoteSettingsChanged
from .events import SettingsAcknowledged as SettingsAcknowledged
from .events import UnknownFrameReceived as UnknownFrameReceived
from .events import WindowUpdated as WindowUpdated
from .exceptions import DenialOfServiceError as DenialOfServiceError
from .exceptions import FlowControlError as FlowControlError
from .exceptions import FrameTooLargeError as FrameTooLargeError
from .exceptions import NoAvailableStreamIDError as NoAvailableStreamIDError
from .exceptions import NoSuchStreamError as NoSuchStreamError
from .exceptions import ProtocolError as ProtocolError
from .exceptions import RFC1122Error as RFC1122Error
from .exceptions import StreamClosedError as StreamClosedError
from .exceptions import StreamIDTooLowError as StreamIDTooLowError
from .exceptions import TooManyStreamsError as TooManyStreamsError
from .frame_buffer import FrameBuffer as FrameBuffer
from .settings import SettingCodes as SettingCodes
from .settings import Settings as Settings
from .stream import H2Stream as H2Stream
from .stream import StreamClosedBy as StreamClosedBy
from .utilities import guard_increment_window as guard_increment_window
from .windows import WindowManager as WindowManager
class ConnectionState(Enum):
IDLE: int
CLIENT_OPEN: int
SERVER_OPEN: int
CLOSED: int
class ConnectionInputs(Enum):
SEND_HEADERS: int
SEND_PUSH_PROMISE: int
SEND_DATA: int
SEND_GOAWAY: int
SEND_WINDOW_UPDATE: int
SEND_PING: int
SEND_SETTINGS: int
SEND_RST_STREAM: int
SEND_PRIORITY: int
RECV_HEADERS: int
RECV_PUSH_PROMISE: int
RECV_DATA: int
RECV_GOAWAY: int
RECV_WINDOW_UPDATE: int
RECV_PING: int
RECV_SETTINGS: int
RECV_RST_STREAM: int
RECV_PRIORITY: int
SEND_ALTERNATIVE_SERVICE: int
RECV_ALTERNATIVE_SERVICE: int
class AllowedStreamIDs(IntEnum):
EVEN: int
ODD: int
class H2ConnectionStateMachine:
state: Incomplete
def __init__(self) -> None: ...
def process_input(self, input_): ...
class H2Connection:
DEFAULT_MAX_OUTBOUND_FRAME_SIZE: int
DEFAULT_MAX_INBOUND_FRAME_SIZE: Incomplete
HIGHEST_ALLOWED_STREAM_ID: Incomplete
MAX_WINDOW_INCREMENT: Incomplete
DEFAULT_MAX_HEADER_LIST_SIZE: Incomplete
MAX_CLOSED_STREAMS: Incomplete
state_machine: Incomplete
streams: Incomplete
highest_inbound_stream_id: int
highest_outbound_stream_id: int
encoder: Incomplete
decoder: Incomplete
config: Incomplete
local_settings: Incomplete
remote_settings: Incomplete
outbound_flow_control_window: Incomplete
max_outbound_frame_size: Incomplete
max_inbound_frame_size: Incomplete
incoming_buffer: Incomplete
def __init__(self, config: Incomplete | None = ...) -> None: ...
@property
def open_outbound_streams(self): ...
@property
def open_inbound_streams(self): ...
@property
def inbound_flow_control_window(self): ...
def initiate_connection(self) -> None: ...
def initiate_upgrade_connection(self, settings_header: Incomplete | None = ...): ...
def get_next_available_stream_id(self): ...
def send_headers(
self,
stream_id,
headers,
end_stream: bool = ...,
priority_weight: Incomplete | None = ...,
priority_depends_on: Incomplete | None = ...,
priority_exclusive: Incomplete | None = ...,
) -> None: ...
def send_data(
self, stream_id, data, end_stream: bool = ..., pad_length: Incomplete | None = ...
) -> None: ...
def end_stream(self, stream_id) -> None: ...
def increment_flow_control_window(
self, increment, stream_id: Incomplete | None = ...
) -> None: ...
def push_stream(self, stream_id, promised_stream_id, request_headers) -> None: ...
def ping(self, opaque_data) -> None: ...
def reset_stream(self, stream_id, error_code: int = ...) -> None: ...
def close_connection(
self,
error_code: int = ...,
additional_data: Incomplete | None = ...,
last_stream_id: Incomplete | None = ...,
) -> None: ...
def update_settings(self, new_settings) -> None: ...
def advertise_alternative_service(
self, field_value, origin: Incomplete | None = ..., stream_id: Incomplete | None = ...
) -> None: ...
def prioritize(
self,
stream_id,
weight: Incomplete | None = ...,
depends_on: Incomplete | None = ...,
exclusive: Incomplete | None = ...,
) -> None: ...
def local_flow_control_window(self, stream_id): ...
def remote_flow_control_window(self, stream_id): ...
def acknowledge_received_data(self, acknowledged_size, stream_id) -> None: ...
def data_to_send(self, amount: Incomplete | None = ...): ...
def clear_outbound_data_buffer(self) -> None: ...
def receive_data(self, data): ...

View File

@@ -0,0 +1,17 @@
import enum
class ErrorCodes(enum.IntEnum):
NO_ERROR: int
PROTOCOL_ERROR: int
INTERNAL_ERROR: int
FLOW_CONTROL_ERROR: int
SETTINGS_TIMEOUT: int
STREAM_CLOSED: int
FRAME_SIZE_ERROR: int
REFUSED_STREAM: int
CANCEL: int
COMPRESSION_ERROR: int
CONNECT_ERROR: int
ENHANCE_YOUR_CALM: int
INADEQUATE_SECURITY: int
HTTP_1_1_REQUIRED: int

View File

@@ -0,0 +1,106 @@
from _typeshed import Incomplete
from .settings import ChangedSetting as ChangedSetting
class Event: ...
class RequestReceived(Event):
stream_id: Incomplete
headers: Incomplete
stream_ended: Incomplete
priority_updated: Incomplete
def __init__(self) -> None: ...
class ResponseReceived(Event):
stream_id: Incomplete
headers: Incomplete
stream_ended: Incomplete
priority_updated: Incomplete
def __init__(self) -> None: ...
class TrailersReceived(Event):
stream_id: Incomplete
headers: Incomplete
stream_ended: Incomplete
priority_updated: Incomplete
def __init__(self) -> None: ...
class _HeadersSent(Event): ...
class _ResponseSent(_HeadersSent): ...
class _RequestSent(_HeadersSent): ...
class _TrailersSent(_HeadersSent): ...
class _PushedRequestSent(_HeadersSent): ...
class InformationalResponseReceived(Event):
stream_id: Incomplete
headers: Incomplete
priority_updated: Incomplete
def __init__(self) -> None: ...
class DataReceived(Event):
stream_id: Incomplete
data: Incomplete
flow_controlled_length: Incomplete
stream_ended: Incomplete
def __init__(self) -> None: ...
class WindowUpdated(Event):
stream_id: Incomplete
delta: Incomplete
def __init__(self) -> None: ...
class RemoteSettingsChanged(Event):
changed_settings: Incomplete
def __init__(self) -> None: ...
@classmethod
def from_settings(cls, old_settings, new_settings): ...
class PingReceived(Event):
ping_data: Incomplete
def __init__(self) -> None: ...
class PingAckReceived(Event):
ping_data: Incomplete
def __init__(self) -> None: ...
class StreamEnded(Event):
stream_id: Incomplete
def __init__(self) -> None: ...
class StreamReset(Event):
stream_id: Incomplete
error_code: Incomplete
remote_reset: bool
def __init__(self) -> None: ...
class PushedStreamReceived(Event):
pushed_stream_id: Incomplete
parent_stream_id: Incomplete
headers: Incomplete
def __init__(self) -> None: ...
class SettingsAcknowledged(Event):
changed_settings: Incomplete
def __init__(self) -> None: ...
class PriorityUpdated(Event):
stream_id: Incomplete
weight: Incomplete
depends_on: Incomplete
exclusive: Incomplete
def __init__(self) -> None: ...
class ConnectionTerminated(Event):
error_code: Incomplete
last_stream_id: Incomplete
additional_data: Incomplete
def __init__(self) -> None: ...
class AlternativeServiceAvailable(Event):
origin: Incomplete
field_value: Incomplete
def __init__(self) -> None: ...
class UnknownFrameReceived(Event):
frame: Incomplete
def __init__(self) -> None: ...

View File

@@ -0,0 +1,48 @@
from _typeshed import Incomplete
class H2Error(Exception): ...
class ProtocolError(H2Error):
error_code: Incomplete
class FrameTooLargeError(ProtocolError):
error_code: Incomplete
class FrameDataMissingError(ProtocolError):
error_code: Incomplete
class TooManyStreamsError(ProtocolError): ...
class FlowControlError(ProtocolError):
error_code: Incomplete
class StreamIDTooLowError(ProtocolError):
stream_id: Incomplete
max_stream_id: Incomplete
def __init__(self, stream_id, max_stream_id) -> None: ...
class NoAvailableStreamIDError(ProtocolError): ...
class NoSuchStreamError(ProtocolError):
stream_id: Incomplete
def __init__(self, stream_id) -> None: ...
class StreamClosedError(NoSuchStreamError):
stream_id: Incomplete
error_code: Incomplete
def __init__(self, stream_id) -> None: ...
class InvalidSettingsValueError(ProtocolError, ValueError):
error_code: Incomplete
def __init__(self, msg, error_code) -> None: ...
class InvalidBodyLengthError(ProtocolError):
expected_length: Incomplete
actual_length: Incomplete
def __init__(self, expected, actual) -> None: ...
class UnsupportedFrameError(ProtocolError): ...
class RFC1122Error(H2Error): ...
class DenialOfServiceError(ProtocolError):
error_code: Incomplete

View File

@@ -0,0 +1,19 @@
from .exceptions import (
FrameDataMissingError as FrameDataMissingError,
)
from .exceptions import (
FrameTooLargeError as FrameTooLargeError,
)
from .exceptions import (
ProtocolError as ProtocolError,
)
CONTINUATION_BACKLOG: int
class FrameBuffer:
data: bytes
max_frame_size: int
def __init__(self, server: bool = ...) -> None: ...
def add_data(self, data) -> None: ...
def __iter__(self): ...
def __next__(self): ...

View File

@@ -0,0 +1,61 @@
import enum
from collections.abc import MutableMapping
from typing import Any
from _typeshed import Incomplete
from h2.errors import ErrorCodes as ErrorCodes
from h2.exceptions import InvalidSettingsValueError as InvalidSettingsValueError
class SettingCodes(enum.IntEnum):
HEADER_TABLE_SIZE: Incomplete
ENABLE_PUSH: Incomplete
MAX_CONCURRENT_STREAMS: Incomplete
INITIAL_WINDOW_SIZE: Incomplete
MAX_FRAME_SIZE: Incomplete
MAX_HEADER_LIST_SIZE: Incomplete
ENABLE_CONNECT_PROTOCOL: Incomplete
class ChangedSetting:
setting: Incomplete
original_value: Incomplete
new_value: Incomplete
def __init__(self, setting, original_value, new_value) -> None: ...
class Settings(MutableMapping[str, Any]):
def __init__(self, client: bool = ..., initial_values: Incomplete | None = ...) -> None: ...
def acknowledge(self): ...
@property
def header_table_size(self): ...
@header_table_size.setter
def header_table_size(self, value) -> None: ...
@property
def enable_push(self): ...
@enable_push.setter
def enable_push(self, value) -> None: ...
@property
def initial_window_size(self): ...
@initial_window_size.setter
def initial_window_size(self, value) -> None: ...
@property
def max_frame_size(self): ...
@max_frame_size.setter
def max_frame_size(self, value) -> None: ...
@property
def max_concurrent_streams(self): ...
@max_concurrent_streams.setter
def max_concurrent_streams(self, value) -> None: ...
@property
def max_header_list_size(self): ...
@max_header_list_size.setter
def max_header_list_size(self, value) -> None: ...
@property
def enable_connect_protocol(self): ...
@enable_connect_protocol.setter
def enable_connect_protocol(self, value) -> None: ...
def __getitem__(self, key): ...
def __setitem__(self, key, value) -> None: ...
def __delitem__(self, key) -> None: ...
def __iter__(self): ...
def __len__(self) -> int: ...
def __eq__(self, other): ...
def __ne__(self, other): ...

View File

@@ -0,0 +1,184 @@
from enum import Enum, IntEnum
from _typeshed import Incomplete
from .errors import ErrorCodes as ErrorCodes
from .events import (
AlternativeServiceAvailable as AlternativeServiceAvailable,
)
from .events import (
DataReceived as DataReceived,
)
from .events import (
InformationalResponseReceived as InformationalResponseReceived,
)
from .events import (
PushedStreamReceived as PushedStreamReceived,
)
from .events import (
RequestReceived as RequestReceived,
)
from .events import (
ResponseReceived as ResponseReceived,
)
from .events import (
StreamEnded as StreamEnded,
)
from .events import (
StreamReset as StreamReset,
)
from .events import (
TrailersReceived as TrailersReceived,
)
from .events import (
WindowUpdated as WindowUpdated,
)
from .exceptions import (
FlowControlError as FlowControlError,
)
from .exceptions import (
InvalidBodyLengthError as InvalidBodyLengthError,
)
from .exceptions import (
ProtocolError as ProtocolError,
)
from .exceptions import (
StreamClosedError as StreamClosedError,
)
from .utilities import (
HeaderValidationFlags as HeaderValidationFlags,
)
from .utilities import (
authority_from_headers as authority_from_headers,
)
from .utilities import (
extract_method_header as extract_method_header,
)
from .utilities import (
guard_increment_window as guard_increment_window,
)
from .utilities import (
is_informational_response as is_informational_response,
)
from .utilities import (
normalize_inbound_headers as normalize_inbound_headers,
)
from .utilities import (
normalize_outbound_headers as normalize_outbound_headers,
)
from .utilities import (
validate_headers as validate_headers,
)
from .utilities import (
validate_outbound_headers as validate_outbound_headers,
)
from .windows import WindowManager as WindowManager
class StreamState(IntEnum):
IDLE: int
RESERVED_REMOTE: int
RESERVED_LOCAL: int
OPEN: int
HALF_CLOSED_REMOTE: int
HALF_CLOSED_LOCAL: int
CLOSED: int
class StreamInputs(Enum):
SEND_HEADERS: int
SEND_PUSH_PROMISE: int
SEND_RST_STREAM: int
SEND_DATA: int
SEND_WINDOW_UPDATE: int
SEND_END_STREAM: int
RECV_HEADERS: int
RECV_PUSH_PROMISE: int
RECV_RST_STREAM: int
RECV_DATA: int
RECV_WINDOW_UPDATE: int
RECV_END_STREAM: int
RECV_CONTINUATION: int
SEND_INFORMATIONAL_HEADERS: int
RECV_INFORMATIONAL_HEADERS: int
SEND_ALTERNATIVE_SERVICE: int
RECV_ALTERNATIVE_SERVICE: int
UPGRADE_CLIENT: int
UPGRADE_SERVER: int
class StreamClosedBy(Enum):
SEND_END_STREAM: int
RECV_END_STREAM: int
SEND_RST_STREAM: int
RECV_RST_STREAM: int
STREAM_OPEN: Incomplete
class H2StreamStateMachine:
state: Incomplete
stream_id: Incomplete
client: Incomplete
headers_sent: Incomplete
trailers_sent: Incomplete
headers_received: Incomplete
trailers_received: Incomplete
stream_closed_by: Incomplete
def __init__(self, stream_id) -> None: ...
def process_input(self, input_): ...
def request_sent(self, previous_state): ...
def response_sent(self, previous_state): ...
def request_received(self, previous_state): ...
def response_received(self, previous_state): ...
def data_received(self, previous_state): ...
def window_updated(self, previous_state): ...
def stream_half_closed(self, previous_state): ...
def stream_ended(self, previous_state): ...
def stream_reset(self, previous_state): ...
def send_new_pushed_stream(self, previous_state): ...
def recv_new_pushed_stream(self, previous_state): ...
def send_push_promise(self, previous_state): ...
def recv_push_promise(self, previous_state): ...
def send_end_stream(self, previous_state) -> None: ...
def send_reset_stream(self, previous_state) -> None: ...
def reset_stream_on_error(self, previous_state) -> None: ...
def recv_on_closed_stream(self, previous_state) -> None: ...
def send_on_closed_stream(self, previous_state) -> None: ...
def recv_push_on_closed_stream(self, previous_state) -> None: ...
def send_push_on_closed_stream(self, previous_state) -> None: ...
def send_informational_response(self, previous_state): ...
def recv_informational_response(self, previous_state): ...
def recv_alt_svc(self, previous_state): ...
def send_alt_svc(self, previous_state) -> None: ...
class H2Stream:
state_machine: Incomplete
stream_id: Incomplete
max_outbound_frame_size: Incomplete
request_method: Incomplete
outbound_flow_control_window: Incomplete
config: Incomplete
def __init__(self, stream_id, config, inbound_window_size, outbound_window_size) -> None: ...
@property
def inbound_flow_control_window(self): ...
@property
def open(self): ...
@property
def closed(self): ...
@property
def closed_by(self): ...
def upgrade(self, client_side) -> None: ...
def send_headers(self, headers, encoder, end_stream: bool = ...): ...
def push_stream_in_band(self, related_stream_id, headers, encoder): ...
def locally_pushed(self): ...
def send_data(self, data, end_stream: bool = ..., pad_length: Incomplete | None = ...): ...
def end_stream(self): ...
def advertise_alternative_service(self, field_value): ...
def increase_flow_control_window(self, increment): ...
def receive_push_promise_in_band(self, promised_stream_id, headers, header_encoding): ...
def remotely_pushed(self, pushed_headers): ...
def receive_headers(self, headers, end_stream, header_encoding): ...
def receive_data(self, data, end_stream, flow_control_len): ...
def receive_window_update(self, increment): ...
def receive_continuation(self) -> None: ...
def receive_alt_svc(self, frame): ...
def reset_stream(self, error_code: int = ...): ...
def stream_reset(self, frame): ...
def acknowledge_received_data(self, acknowledged_size): ...

View File

@@ -0,0 +1,25 @@
from typing import NamedTuple
from _typeshed import Incomplete
from .exceptions import FlowControlError as FlowControlError
from .exceptions import ProtocolError as ProtocolError
UPPER_RE: Incomplete
CONNECTION_HEADERS: Incomplete
def extract_method_header(headers): ...
def is_informational_response(headers): ...
def guard_increment_window(current, increment): ...
def authority_from_headers(headers): ...
class HeaderValidationFlags(NamedTuple):
is_client: Incomplete
is_trailer: Incomplete
is_response_header: Incomplete
is_push_promise: Incomplete
def validate_headers(headers, hdr_validation_flags): ...
def normalize_outbound_headers(headers, hdr_validation_flags): ...
def normalize_inbound_headers(headers, hdr_validation_flags): ...
def validate_outbound_headers(headers, hdr_validation_flags): ...

View File

@@ -0,0 +1,13 @@
from _typeshed import Incomplete
from .exceptions import FlowControlError as FlowControlError
LARGEST_FLOW_CONTROL_WINDOW: Incomplete
class WindowManager:
max_window_size: Incomplete
current_window_size: Incomplete
def __init__(self, max_window_size) -> None: ...
def window_consumed(self, size) -> None: ...
def window_opened(self, size) -> None: ...
def process_bytes(self, size): ...