Compare commits

..

8 Commits

Author SHA1 Message Date
Arseny Sher
1bf5e07da1 Try to enable a custom postgres_exporter query. 2023-11-16 19:35:04 +01:00
Arseny Sher
87de91b004 Add pg_wait_sampling extension. 2023-11-16 19:34:48 +01:00
Arthur Petukhovsky
2f217f9ebd Print pid to the logs 2023-11-16 19:34:32 +01:00
Anna Khanova
71491dd467 Fix build 2023-11-16 19:34:16 +01:00
Anna Khanova
67e791c4ec Fmt 2023-11-16 19:34:02 +01:00
Anna Khanova
517782ab94 Log pid in proxy 2023-11-16 19:33:54 +01:00
Em Sharnoff
d0a842a509 Update vm-builder to v0.19.0 and move its customization here (#5783)
ref neondatabase/autoscaling#600 for more
2023-11-16 18:17:42 +01:00
khanova
6b82f22ada Collect number of connections by sni type (#5867)
## Problem

We don't know how many users rely on each kind of
authentication: ["sni", "endpoint in options" (A and B from
[here](https://neon.tech/docs/connect/connection-errors)),
"password_hack"].

## Summary of changes

Collect metrics by sni kind.
2023-11-16 12:19:13 +00:00
17 changed files with 280 additions and 475 deletions

View File

@@ -852,7 +852,7 @@ jobs:
run:
shell: sh -eu {0}
env:
VM_BUILDER_VERSION: v0.18.5
VM_BUILDER_VERSION: v0.19.0
steps:
- name: Checkout
@@ -874,8 +874,7 @@ jobs:
- name: Build vm image
run: |
./vm-builder \
-enable-file-cache \
-cgroup-uid=postgres \
-spec=vm-image-spec.yaml \
-src=369495373322.dkr.ecr.eu-central-1.amazonaws.com/compute-node-${{ matrix.version }}:${{needs.tag.outputs.build-tag}} \
-dst=369495373322.dkr.ecr.eu-central-1.amazonaws.com/vm-compute-node-${{ matrix.version }}:${{needs.tag.outputs.build-tag}}

91
Cargo.lock generated
View File

@@ -274,7 +274,7 @@ dependencies = [
"hex",
"http",
"hyper",
"ring 0.16.20",
"ring",
"time",
"tokio",
"tower",
@@ -703,7 +703,7 @@ dependencies = [
"bytes",
"dyn-clone",
"futures",
"getrandom 0.2.11",
"getrandom 0.2.9",
"http-types",
"log",
"paste",
@@ -863,22 +863,6 @@ dependencies = [
"which",
]
[[package]]
name = "biscuit"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7e28fc7c56c61743a01d0d1b73e4fed68b8a4f032ea3a2d4bb8c6520a33fc05a"
dependencies = [
"chrono",
"data-encoding",
"num-bigint",
"num-traits",
"once_cell",
"ring 0.17.5",
"serde",
"serde_json",
]
[[package]]
name = "bitflags"
version = "1.3.2"
@@ -961,12 +945,11 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
[[package]]
name = "cc"
version = "1.0.83"
version = "1.0.79"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0"
checksum = "50d30906286121d95be3d479533b458f87493b30a4b5f79a607db8f5d11aa91f"
dependencies = [
"jobserver",
"libc",
]
[[package]]
@@ -1863,9 +1846,9 @@ dependencies = [
[[package]]
name = "getrandom"
version = "0.2.11"
version = "0.2.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fe9006bed769170c11f845cf00c7c1e9092aeb3f268e007c3e760ac68008070f"
checksum = "c85e1d9ab2eadba7e5040d4e09cbd6d072b76a557ad64e797c2cb9d4da21d7e4"
dependencies = [
"cfg-if",
"js-sys",
@@ -2359,7 +2342,7 @@ checksum = "6971da4d9c3aa03c3d8f3ff0f4155b534aad021292003895a469716b2a230378"
dependencies = [
"base64 0.21.1",
"pem 1.1.1",
"ring 0.16.20",
"ring",
"serde",
"serde_json",
"simple_asn1",
@@ -2399,9 +2382,9 @@ checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"
[[package]]
name = "libc"
version = "0.2.150"
version = "0.2.144"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "89d92a4743f9a61002fae18374ed11e7973f530cb3a3255fb354818118b2203c"
checksum = "2b00cc1c228a6782d0f076e7b232802e0c5689d41bb5df366f2a6b6621cfdfe1"
[[package]]
name = "libloading"
@@ -2708,7 +2691,7 @@ checksum = "c38841cdd844847e3e7c8d29cef9dcfed8877f8f56f9071f77843ecf3baf937f"
dependencies = [
"base64 0.13.1",
"chrono",
"getrandom 0.2.11",
"getrandom 0.2.9",
"http",
"rand 0.8.5",
"serde",
@@ -3491,7 +3474,6 @@ dependencies = [
"anyhow",
"async-trait",
"base64 0.13.1",
"biscuit",
"bstr",
"bytes",
"chrono",
@@ -3637,7 +3619,7 @@ version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
dependencies = [
"getrandom 0.2.11",
"getrandom 0.2.9",
]
[[package]]
@@ -3678,7 +3660,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4954fbc00dcd4d8282c987710e50ba513d351400dbdd00e803a05172a90d8976"
dependencies = [
"pem 2.0.1",
"ring 0.16.20",
"ring",
"time",
"yasna",
]
@@ -3848,7 +3830,7 @@ dependencies = [
"async-trait",
"chrono",
"futures",
"getrandom 0.2.11",
"getrandom 0.2.9",
"http",
"hyper",
"parking_lot 0.11.2",
@@ -3869,7 +3851,7 @@ checksum = "1b97ad83c2fc18113346b7158d79732242002427c30f620fa817c1f32901e0a8"
dependencies = [
"anyhow",
"async-trait",
"getrandom 0.2.11",
"getrandom 0.2.9",
"matchit",
"opentelemetry",
"reqwest",
@@ -3900,25 +3882,11 @@ dependencies = [
"libc",
"once_cell",
"spin 0.5.2",
"untrusted 0.7.1",
"untrusted",
"web-sys",
"winapi",
]
[[package]]
name = "ring"
version = "0.17.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fb0205304757e5d899b9c2e448b867ffd03ae7f988002e47cd24954391394d0b"
dependencies = [
"cc",
"getrandom 0.2.11",
"libc",
"spin 0.9.8",
"untrusted 0.9.0",
"windows-sys 0.48.0",
]
[[package]]
name = "routerify"
version = "3.0.0"
@@ -4035,7 +4003,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d1feddffcfcc0b33f5c6ce9a29e341e4cd59c3f78e7ee45f4a40c038b1d6cbb"
dependencies = [
"log",
"ring 0.16.20",
"ring",
"rustls-webpki 0.101.4",
"sct",
]
@@ -4067,8 +4035,8 @@ version = "0.100.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e98ff011474fa39949b7e5c0428f9b4937eda7da7848bbb947786b7be0b27dab"
dependencies = [
"ring 0.16.20",
"untrusted 0.7.1",
"ring",
"untrusted",
]
[[package]]
@@ -4077,8 +4045,8 @@ version = "0.101.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7d93931baf2d282fff8d3a532bbfd7653f734643161b87e3e01e59a04439bf0d"
dependencies = [
"ring 0.16.20",
"untrusted 0.7.1",
"ring",
"untrusted",
]
[[package]]
@@ -4223,8 +4191,8 @@ version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d53dcdb7c9f8158937a7981b48accfd39a43af418591a5d008c7b22b5e1b7ca4"
dependencies = [
"ring 0.16.20",
"untrusted 0.7.1",
"ring",
"untrusted",
]
[[package]]
@@ -4343,7 +4311,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "99dc599bd6646884fc403d593cdcb9816dd67c50cff3271c01ff123617908dcd"
dependencies = [
"debugid",
"getrandom 0.2.11",
"getrandom 0.2.9",
"hex",
"serde",
"serde_json",
@@ -4389,7 +4357,6 @@ version = "1.0.96"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "057d394a50403bcac12672b2b18fb387ab6d289d957dab67dd201875391e52f1"
dependencies = [
"indexmap",
"itoa",
"ryu",
"serde",
@@ -4993,7 +4960,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd5831152cb0d3f79ef5523b357319ba154795d64c7078b2daa95a803b54057f"
dependencies = [
"futures",
"ring 0.16.20",
"ring",
"rustls",
"tokio",
"tokio-postgres",
@@ -5449,12 +5416,6 @@ version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a"
[[package]]
name = "untrusted"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"
[[package]]
name = "ureq"
version = "2.7.1"
@@ -5556,7 +5517,7 @@ version = "1.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "345444e32442451b267fc254ae85a209c64be56d2890e601a0c37ff0c3c5ecd2"
dependencies = [
"getrandom 0.2.11",
"getrandom 0.2.9",
"serde",
]
@@ -6049,7 +6010,7 @@ dependencies = [
"regex",
"regex-syntax 0.7.2",
"reqwest",
"ring 0.16.20",
"ring",
"rustls",
"scopeguard",
"serde",

View File

@@ -714,6 +714,23 @@ RUN wget https://github.com/pksunkara/pgx_ulid/archive/refs/tags/v0.1.3.tar.gz -
cargo pgrx install --release && \
echo "trusted = true" >> /usr/local/pgsql/share/extension/ulid.control
#########################################################################################
#
# Layer "pg-wait-sampling-pg-build"
# compile pg_wait_sampling extension
#
#########################################################################################
FROM build-deps AS pg-wait-sampling-pg-build
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
ENV PATH "/usr/local/pgsql/bin/:$PATH"
RUN wget https://github.com/postgrespro/pg_wait_sampling/archive/refs/tags/v1.1.5.tar.gz -O pg_wait_sampling.tar.gz && \
echo 'a03da6a413f5652ce470a3635ed6ebba528c74cb26aa4cfced8aff8a8441f81ec6dd657ff62cd6ce96a4e6ce02cad9f2519ae9525367ece60497aa20faafde5c pg_wait_sampling.tar.gz' | sha512sum -c && \
mkdir pg_wait_sampling-src && cd pg_wait_sampling-src && tar xvzf ../pg_wait_sampling.tar.gz --strip-components=1 -C . && \
make USE_PGXS=1 -j $(getconf _NPROCESSORS_ONLN) && \
make USE_PGXS=1 -j $(getconf _NPROCESSORS_ONLN) install && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/pg_wait_sampling.control
#########################################################################################
#
# Layer "neon-pg-ext-build"
@@ -750,6 +767,7 @@ COPY --from=rdkit-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-uuidv7-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-roaringbitmap-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-embedding-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-wait-sampling-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY pgxn/ pgxn/
RUN make -j $(getconf _NPROCESSORS_ONLN) \

View File

@@ -68,7 +68,7 @@ webpki-roots.workspace = true
x509-parser.workspace = true
native-tls.workspace = true
postgres-native-tls.workspace = true
biscuit = { version = "0.7",features = [] }
workspace_hack.workspace = true
tokio-util.workspace = true

View File

@@ -3,10 +3,8 @@ mod hacks;
mod link;
pub use link::LinkAuthError;
use serde::{Deserialize, Serialize};
use tokio_postgres::config::AuthKeys;
use crate::console::provider::neon::UserRowLevel;
use crate::proxy::{handle_try_wake, retry_after, LatencyTimer};
use crate::{
auth::{self, ClientCredentials},
@@ -321,41 +319,4 @@ impl BackendType<'_, ClientCredentials<'_>> {
Test(x) => x.wake_compute().map(Some),
}
}
/// Get the password for the RLS user
pub async fn ensure_row_level(
&self,
extra: &ConsoleReqExtra<'_>,
dbname: String,
username: String,
policies: Vec<Policy>,
) -> anyhow::Result<UserRowLevel> {
use BackendType::*;
match self {
Console(api, creds) => {
api.ensure_row_level(extra, creds, dbname, username, policies)
.await
}
Postgres(api, creds) => {
api.ensure_row_level(extra, creds, dbname, username, policies)
.await
}
Link(_) => Err(anyhow::anyhow!("not on link")),
Test(_) => Err(anyhow::anyhow!("not on test")),
}
}
}
// TODO(conrad): policies can be quite complex. Figure out how to configure this
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct Policy {
pub table: String,
pub column: String,
}
// enum PolicyType {
// ForSelect(),
// ForUpdate()
// }

View File

@@ -1,7 +1,9 @@
//! User credentials used in authentication.
use crate::{
auth::password_hack::parse_endpoint_param, error::UserFacingError, proxy::neon_options,
auth::password_hack::parse_endpoint_param,
error::UserFacingError,
proxy::{neon_options, NUM_CONNECTION_ACCEPTED_BY_SNI},
};
use itertools::Itertools;
use pq_proto::StartupMessageParams;
@@ -124,6 +126,22 @@ impl<'a> ClientCredentials<'a> {
.transpose()?;
info!(user, project = project.as_deref(), "credentials");
if sni.is_some() {
info!("Connection with sni");
NUM_CONNECTION_ACCEPTED_BY_SNI
.with_label_values(&["sni"])
.inc();
} else if project.is_some() {
NUM_CONNECTION_ACCEPTED_BY_SNI
.with_label_values(&["no_sni"])
.inc();
info!("Connection without sni");
} else {
NUM_CONNECTION_ACCEPTED_BY_SNI
.with_label_values(&["password_hack"])
.inc();
info!("Connection with password hack");
}
let cache_key = format!(
"{}{}",

View File

@@ -1,11 +1,9 @@
pub mod mock;
pub mod neon;
use self::neon::UserRowLevel;
use super::messages::MetricsAuxInfo;
use crate::{
auth::{backend::Policy, ClientCredentials},
auth::ClientCredentials,
cache::{timed_lru, TimedLru},
compute, scram,
};
@@ -250,16 +248,6 @@ pub trait Api {
extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials,
) -> Result<CachedNodeInfo, errors::WakeComputeError>;
/// Get the password for the RLS user
async fn ensure_row_level(
&self,
extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials,
dbname: String,
username: String,
policies: Vec<Policy>,
) -> anyhow::Result<UserRowLevel>;
}
/// Various caches for [`console`](super).

View File

@@ -2,16 +2,9 @@
use super::{
errors::{ApiError, GetAuthInfoError, WakeComputeError},
neon::UserRowLevel,
AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
};
use crate::{
auth::{backend::Policy, ClientCredentials},
compute,
error::io_error,
scram,
url::ApiUrl,
};
use crate::{auth::ClientCredentials, compute, error::io_error, scram, url::ApiUrl};
use async_trait::async_trait;
use futures::TryFutureExt;
use thiserror::Error;
@@ -128,18 +121,6 @@ impl super::Api for Api {
.map_ok(CachedNodeInfo::new_uncached)
.await
}
/// Get the password for the RLS user
async fn ensure_row_level(
&self,
_extra: &ConsoleReqExtra<'_>,
_creds: &ClientCredentials,
_dbname: String,
_username: String,
_policies: Vec<Policy>,
) -> anyhow::Result<UserRowLevel> {
Err(anyhow::anyhow!("unimplemented"))
}
}
fn parse_md5(input: &str) -> Option<[u8; 16]> {

View File

@@ -5,13 +5,9 @@ use super::{
errors::{ApiError, GetAuthInfoError, WakeComputeError},
ApiCaches, ApiLocks, AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
};
use crate::{
auth::{backend::Policy, ClientCredentials},
compute, http, scram,
};
use crate::{auth::ClientCredentials, compute, http, scram};
use async_trait::async_trait;
use futures::TryFutureExt;
use serde::{Deserialize, Serialize};
use std::{net::SocketAddr, sync::Arc};
use tokio::time::Instant;
use tokio_postgres::config::SslMode;
@@ -143,84 +139,6 @@ impl Api {
.instrument(info_span!("http", id = request_id))
.await
}
async fn do_ensure_row_level(
&self,
extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials<'_>,
dbname: String,
username: String,
policies: Vec<Policy>,
) -> anyhow::Result<UserRowLevel> {
let project = creds.project().expect("impossible");
let request_id = uuid::Uuid::new_v4().to_string();
async {
let request = self
.endpoint
.post("proxy_ensure_role_level_sec")
.header("X-Request-ID", &request_id)
.header("Authorization", format!("Bearer {}", &self.jwt))
.query(&[("session_id", extra.session_id)])
// .query(&[
// ("application_name", extra.application_name),
// ("project", Some(project)),
// ("dbname", Some(&dbname)),
// ("username", Some(&username)),
// ("options", extra.options),
// ])
.json(&EnsureRowLevelReq {
project: project.to_owned(),
targets: policies
.into_iter()
.map(|p| Target {
database_name: dbname.clone(),
table_name: p.table,
row_level_user_id: username.clone(),
role_name: "enduser".to_owned(),
column_name: p.column,
})
.collect(),
})
.build()?;
info!(url = request.url().as_str(), "sending http request");
let start = Instant::now();
let response = self.endpoint.execute(request).await?;
info!(duration = ?start.elapsed(), "received http response");
let mut body = parse_body::<UserRowLevel>(response).await?;
// hack
body.username = body.username.to_lowercase();
// info!(user = %body.username, pw=%body.password, "please don't merge this in production");
Ok(body)
}
.map_err(crate::error::log_error)
.instrument(info_span!("http", id = request_id))
.await
}
}
#[derive(Serialize)]
struct EnsureRowLevelReq {
project: String,
targets: Vec<Target>,
}
#[derive(Serialize)]
struct Target {
database_name: String,
table_name: String,
row_level_user_id: String,
role_name: String,
column_name: String,
}
#[derive(Deserialize)]
pub struct UserRowLevel {
pub username: String,
pub password: String,
}
#[async_trait]
@@ -270,20 +188,6 @@ impl super::Api for Api {
Ok(cached)
}
/// Get the password for the RLS user
#[tracing::instrument(skip_all)]
async fn ensure_row_level(
&self,
extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials,
dbname: String,
username: String,
policies: Vec<Policy>,
) -> anyhow::Result<UserRowLevel> {
self.do_ensure_row_level(extra, creds, dbname, username, policies)
.await
}
}
/// Parse http response body, taking status code into account.

View File

@@ -88,14 +88,6 @@ impl Endpoint {
self.client.get(url.into_inner())
}
/// Return a [builder](RequestBuilder) for a `POST` request,
/// appending a single `path` segment to the base endpoint URL.
pub fn post(&self, path: &str) -> RequestBuilder {
let mut url = self.endpoint.clone();
url.path_segments_mut().push(path);
self.client.post(url.into_inner())
}
/// Execute a [request](reqwest::Request).
pub async fn execute(&self, request: Request) -> Result<Response, Error> {
self.client.execute(request).await

View File

@@ -129,6 +129,15 @@ pub static RATE_LIMITER_LIMIT: Lazy<IntGaugeVec> = Lazy::new(|| {
.unwrap()
});
pub static NUM_CONNECTION_ACCEPTED_BY_SNI: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"proxy_accepted_connections_by_sni",
"Number of connections (per sni).",
&["kind"],
)
.unwrap()
});
pub struct LatencyTimer {
// time since the stopwatch was started
start: Option<Instant>,

View File

@@ -3,12 +3,10 @@
//! Handles both SQL over HTTP and SQL over Websockets.
mod conn_pool;
pub mod jwt_auth;
mod sql_over_http;
mod websocket;
use anyhow::bail;
use dashmap::DashMap;
use hyper::StatusCode;
pub use reqwest_middleware::{ClientWithMiddleware, Error};
pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
@@ -33,8 +31,6 @@ use tokio_util::sync::CancellationToken;
use tracing::{error, info, info_span, warn, Instrument};
use utils::http::{error::ApiError, json::json_response};
use self::jwt_auth::JWKSetCaches;
pub async fn task_main(
config: &'static ProxyConfig,
ws_listener: TcpListener,
@@ -45,9 +41,6 @@ pub async fn task_main(
}
let conn_pool = conn_pool::GlobalConnPool::new(config);
let jwk_cache_pool = Arc::new(JWKSetCaches {
map: DashMap::new(),
});
// shutdown the connection pool
tokio::spawn({
@@ -92,7 +85,6 @@ pub async fn task_main(
let remote_addr = io.inner.remote_addr();
let sni_name = tls.server_name().map(|s| s.to_string());
let conn_pool = conn_pool.clone();
let jwk_cache_pool = jwk_cache_pool.clone();
async move {
let peer_addr = match client_addr {
@@ -104,20 +96,13 @@ pub async fn task_main(
move |req: Request<Body>| {
let sni_name = sni_name.clone();
let conn_pool = conn_pool.clone();
let jwk_cache_pool = jwk_cache_pool.clone();
async move {
let cancel_map = Arc::new(CancelMap::default());
let session_id = uuid::Uuid::new_v4();
request_handler(
req,
config,
conn_pool,
jwk_cache_pool,
cancel_map,
session_id,
sni_name,
req, config, conn_pool, cancel_map, session_id, sni_name,
)
.instrument(info_span!(
"serverless",
@@ -182,7 +167,6 @@ async fn request_handler(
mut request: Request<Body>,
config: &'static ProxyConfig,
conn_pool: Arc<conn_pool::GlobalConnPool>,
jwk_cache_pool: Arc<JWKSetCaches>,
cancel_map: Arc<CancelMap>,
session_id: uuid::Uuid,
sni_hostname: Option<String>,
@@ -220,7 +204,6 @@ async fn request_handler(
request,
sni_hostname,
conn_pool,
jwk_cache_pool,
session_id,
&config.http_config,
)
@@ -231,7 +214,7 @@ async fn request_handler(
.header("Access-Control-Allow-Origin", "*")
.header(
"Access-Control-Allow-Headers",
"Neon-Connection-String, Neon-Raw-Text-Output, Neon-Array-Mode, Neon-Pool-Opt-In, Authorization",
"Neon-Connection-String, Neon-Raw-Text-Output, Neon-Array-Mode, Neon-Pool-Opt-In",
)
.header("Access-Control-Max-Age", "86400" /* 24 hours */)
.status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code

View File

@@ -21,8 +21,7 @@ use tokio::time;
use tokio_postgres::{AsyncMessage, ReadyForQueryStatus};
use crate::{
auth::{self, backend::Policy},
console::{self, provider::neon::UserRowLevel},
auth, console,
proxy::{
neon_options, LatencyTimer, NUM_DB_CONNECTIONS_CLOSED_COUNTER,
NUM_DB_CONNECTIONS_OPENED_COUNTER,
@@ -46,8 +45,6 @@ pub struct ConnInfo {
pub hostname: String,
pub password: String,
pub options: Option<String>,
/// row level security mode enabled
pub policies: Option<Vec<Policy>>,
}
impl ConnInfo {
@@ -213,12 +210,7 @@ impl GlobalConnPool {
client.session.send(session_id)?;
latency_timer.pool_hit();
latency_timer.success();
return Ok(Client {
conn_id: client.conn_id,
inner: Some(client),
span: Span::current(),
pool,
});
return Ok(Client::new(client, pool).await);
}
} else {
let conn_id = uuid::Uuid::new_v4();
@@ -266,15 +258,11 @@ impl GlobalConnPool {
_ => {}
}
new_client.map(|inner| Client {
conn_id: inner.conn_id,
inner: Some(inner),
span: Span::current(),
pool,
})
// new_client.map(|inner| Client::new(inner, pool).await)
Ok(Client::new(new_client?, pool).await)
}
fn put(&self, conn_info: &ConnInfo, client: ClientInner) -> anyhow::Result<()> {
fn put(&self, conn_info: &ConnInfo, client: ClientInner, pid: i32) -> anyhow::Result<()> {
let conn_id = client.conn_id;
// We want to hold this open while we return. This ensures that the pool can't close
@@ -318,9 +306,9 @@ impl GlobalConnPool {
// do logging outside of the mutex
if returned {
info!(%conn_id, "pool: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
info!(%conn_id, "pool: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}, pid={pid}");
} else {
info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}, pid={pid}");
}
Ok(())
@@ -368,7 +356,6 @@ struct TokioMechanism<'a> {
conn_info: &'a ConnInfo,
session_id: uuid::Uuid,
conn_id: uuid::Uuid,
row_level: Option<UserRowLevel>,
}
#[async_trait]
@@ -388,7 +375,6 @@ impl ConnectMechanism for TokioMechanism<'_> {
timeout,
self.conn_id,
self.session_id,
&self.row_level,
)
.await
}
@@ -436,26 +422,11 @@ async fn connect_to_compute(
.await?
.context("missing cache entry from wake_compute")?;
let mut row_level = None;
if let Some(policies) = &conn_info.policies {
row_level = Some(
creds
.ensure_row_level(
&extra,
conn_info.dbname.to_owned(),
conn_info.username.to_owned(),
policies.clone(),
)
.await?,
);
}
crate::proxy::connect_to_compute(
&TokioMechanism {
conn_id,
conn_info,
session_id,
row_level,
},
node_info,
&extra,
@@ -471,24 +442,12 @@ async fn connect_to_compute_once(
timeout: time::Duration,
conn_id: uuid::Uuid,
mut session: uuid::Uuid,
row_level: &Option<UserRowLevel>,
) -> Result<ClientInner, tokio_postgres::Error> {
let mut config = (*node_info.config).clone();
let username = row_level
.as_ref()
.map(|r| &r.username)
.unwrap_or(&conn_info.username);
info!(%username, dbname = %conn_info.dbname, "connecting");
let (client, mut connection) = config
.user(username)
.password(
row_level
.as_ref()
.map(|r| &r.password)
.unwrap_or(&conn_info.password),
)
.user(&conn_info.username)
.password(&conn_info.password)
.dbname(&conn_info.dbname)
.connect_timeout(timeout)
.connect(tokio_postgres::NoTls)
@@ -560,6 +519,22 @@ struct ClientInner {
conn_id: uuid::Uuid,
}
impl ClientInner {
pub async fn get_pid(&mut self) -> anyhow::Result<i32> {
let rows = self.inner.query("select pg_backend_pid();", &[]).await?;
if rows.len() != 1 {
Err(anyhow::anyhow!(
"expected 1 row from pg_backend_pid(), got {}",
rows.len()
))
} else {
let pid = rows[0].get(0);
info!(%pid, "got pid");
Ok(pid)
}
}
}
impl Client {
pub fn metrics(&self) -> Arc<MetricCounter> {
USAGE_METRICS.register(self.inner.as_ref().unwrap().ids.clone())
@@ -571,6 +546,7 @@ pub struct Client {
span: Span,
inner: Option<ClientInner>,
pool: Option<(ConnInfo, Arc<GlobalConnPool>)>,
pid: i32,
}
pub struct Discard<'a> {
@@ -579,12 +555,25 @@ pub struct Discard<'a> {
}
impl Client {
pub(self) async fn new(
mut inner: ClientInner,
pool: Option<(ConnInfo, Arc<GlobalConnPool>)>,
) -> Self {
Self {
conn_id: inner.conn_id,
pid: inner.get_pid().await.unwrap_or(-1),
inner: Some(inner),
span: Span::current(),
pool,
}
}
pub fn inner(&mut self) -> (&mut tokio_postgres::Client, Discard<'_>) {
let Self {
inner,
pool,
conn_id,
span: _,
pid: _,
} = self;
(
&mut inner
@@ -641,10 +630,11 @@ impl Drop for Client {
.expect("client inner should not be removed");
if let Some((conn_info, conn_pool)) = self.pool.take() {
let current_span = self.span.clone();
let pid = self.pid;
// return connection to the pool
tokio::task::spawn_blocking(move || {
let _span = current_span.enter();
let _ = conn_pool.put(&conn_info, client);
let _ = conn_pool.put(&conn_info, client, pid);
});
}
}

View File

@@ -1,98 +0,0 @@
// https://adapted-gorilla-88.clerk.accounts.dev/.well-known/jwks.json
use std::sync::Arc;
use anyhow::{bail, Context};
use biscuit::{
jwk::{JWKSet, JWK},
jws, CompactPart,
};
use dashmap::DashMap;
use reqwest::{IntoUrl, Url};
use serde::{de::DeserializeOwned, Serialize};
use tokio::sync::RwLock;
pub struct JWKSetCaches {
pub map: DashMap<Url, Arc<JWKSetCache>>,
}
impl JWKSetCaches {
pub async fn get_cache(&self, url: impl IntoUrl) -> anyhow::Result<Arc<JWKSetCache>> {
let url = url.into_url()?;
if let Some(x) = self.map.get(&url) {
return Ok(x.clone());
}
let cache = JWKSetCache::new(url.clone()).await?;
let cache = Arc::new(cache);
self.map.insert(url, cache.clone());
Ok(cache)
}
}
pub struct JWKSetCache {
url: Url,
current: RwLock<biscuit::jwk::JWKSet<()>>,
}
impl JWKSetCache {
pub async fn new(url: impl IntoUrl) -> anyhow::Result<Self> {
let url = url.into_url()?;
let current = reqwest::get(url.clone()).await?.json().await?;
Ok(Self {
url,
current: RwLock::new(current),
})
}
pub async fn get(&self, kid: &str) -> anyhow::Result<JWK<()>> {
let current = self.current.read().await.clone();
if let Some(key) = current.find(kid) {
return Ok(key.clone());
}
let new = reqwest::get(self.url.clone()).await?.json().await?;
if new == current {
bail!("not found")
}
*self.current.write().await = new;
current.find(kid).cloned().context("not found")
}
pub async fn decode<T, H>(
&self,
token: &jws::Compact<T, H>,
) -> anyhow::Result<jws::Compact<T, H>>
where
T: CompactPart,
H: Serialize + DeserializeOwned,
{
let current = self.current.read().await.clone();
match token.decode_with_jwks(&current, None) {
Ok(t) => Ok(t),
Err(biscuit::errors::Error::ValidationError(
biscuit::errors::ValidationError::KeyNotFound,
)) => {
let new: JWKSet<()> = reqwest::get(self.url.clone()).await?.json().await?;
if new == current {
bail!("not found")
}
*self.current.write().await = new.clone();
token.decode_with_jwks(&new, None).context("error")
// current.find(kid).cloned().context("not found")
}
Err(e) => Err(e.into()),
}
}
}
#[cfg(test)]
mod tests {
use super::JWKSetCache;
#[tokio::test]
async fn jwkset() {
let cache =
JWKSetCache::new("https://adapted-gorilla-88.clerk.accounts.dev/.well-known/jwks.json")
.await
.unwrap();
dbg!(cache.get("ins_2YFechxysnwZcZN6TDHEz6u6w6v").await.unwrap());
}
}

View File

@@ -1,20 +1,15 @@
use std::sync::Arc;
use anyhow::bail;
use anyhow::Context;
use biscuit::JWT;
use futures::pin_mut;
use futures::StreamExt;
use hyper::body::HttpBody;
use hyper::header;
use hyper::header::AUTHORIZATION;
use hyper::http::HeaderName;
use hyper::http::HeaderValue;
use hyper::Response;
use hyper::StatusCode;
use hyper::{Body, HeaderMap, Request};
use serde::Deserialize;
use serde::Serialize;
use serde_json::json;
use serde_json::Map;
use serde_json::Value;
@@ -31,13 +26,11 @@ use url::Url;
use utils::http::error::ApiError;
use utils::http::json::json_response;
use crate::auth::backend::Policy;
use crate::config::HttpConfig;
use crate::proxy::{NUM_CONNECTIONS_ACCEPTED_COUNTER, NUM_CONNECTIONS_CLOSED_COUNTER};
use super::conn_pool::ConnInfo;
use super::conn_pool::GlobalConnPool;
use super::jwt_auth::JWKSetCaches;
#[derive(serde::Deserialize)]
struct QueryData {
@@ -125,10 +118,9 @@ fn json_array_to_pg_array(value: &Value) -> Option<String> {
}
}
async fn get_conn_info(
jwk_cache_pool: &JWKSetCaches,
fn get_conn_info(
headers: &HeaderMap,
sni_hostname: &str,
sni_hostname: Option<String>,
) -> Result<ConnInfo, anyhow::Error> {
let connection_string = headers
.get("Neon-Connection-String")
@@ -152,42 +144,18 @@ async fn get_conn_info(
.next()
.ok_or(anyhow::anyhow!("invalid database name"))?;
let mut password = "";
let mut policies = None;
let authorization = headers.get(AUTHORIZATION);
let username = if let Some(auth) = authorization {
// TODO: introduce control plane API to fetch this
let jwks_url = match sni_hostname {
"ep-flat-night-23370355.cloud.krypton.aws.neon.build" => {
"https://adapted-gorilla-88.clerk.accounts.dev/.well-known/jwks.json"
}
_ => anyhow::bail!("jwt auth not supported"),
};
let jwk_cache = jwk_cache_pool.get_cache(jwks_url).await?;
let username = connection_url.username();
if username.is_empty() {
return Err(anyhow::anyhow!("missing username"));
}
let auth = auth.to_str()?;
let token = auth.strip_prefix("Bearer ").context("bad token")?;
let jwt: JWT<NeonFields, ()> = JWT::new_encoded(token);
let token = jwk_cache.decode(&jwt).await?;
let payload = token.payload().unwrap();
policies = Some(payload.private.policies.clone());
payload
.registered
.subject
.as_deref()
.context("missing user id")?
.to_owned()
} else {
password = connection_url
.password()
.ok_or(anyhow::anyhow!("no password"))?;
let password = connection_url
.password()
.ok_or(anyhow::anyhow!("no password"))?;
let u = connection_url.username();
if u.is_empty() {
return Err(anyhow::anyhow!("missing username"));
}
u.to_owned()
};
// TLS certificate selector now based on SNI hostname, so if we are running here
// we are sure that SNI hostname is set to one of the configured domain names.
let sni_hostname = sni_hostname.ok_or(anyhow::anyhow!("no SNI hostname set"))?;
let hostname = connection_url
.host_str()
@@ -218,8 +186,7 @@ async fn get_conn_info(
}
Ok(ConnInfo {
username,
policies,
username: username.to_owned(),
dbname: dbname.to_owned(),
hostname: hostname.to_owned(),
password: password.to_owned(),
@@ -232,13 +199,12 @@ pub async fn handle(
request: Request<Body>,
sni_hostname: Option<String>,
conn_pool: Arc<GlobalConnPool>,
jwk_cache_pool: Arc<JWKSetCaches>,
session_id: uuid::Uuid,
config: &'static HttpConfig,
) -> Result<Response<Body>, ApiError> {
let result = tokio::time::timeout(
config.sql_over_http_timeout,
handle_inner(request, sni_hostname, conn_pool, jwk_cache_pool, session_id),
handle_inner(request, sni_hostname, conn_pool, session_id),
)
.await;
let mut response = match result {
@@ -289,7 +255,6 @@ async fn handle_inner(
request: Request<Body>,
sni_hostname: Option<String>,
conn_pool: Arc<GlobalConnPool>,
jwk_cache_pool: Arc<JWKSetCaches>,
session_id: uuid::Uuid,
) -> anyhow::Result<Response<Body>> {
NUM_CONNECTIONS_ACCEPTED_COUNTER
@@ -299,15 +264,11 @@ async fn handle_inner(
NUM_CONNECTIONS_CLOSED_COUNTER.with_label_values(&["http"]).inc();
}
// TLS certificate selector now based on SNI hostname, so if we are running here
// we are sure that SNI hostname is set to one of the configured domain names.
let sni_hostname = sni_hostname.ok_or(anyhow::anyhow!("no SNI hostname set"))?;
//
// Determine the destination and connection params
//
let headers = request.headers();
let conn_info = get_conn_info(&jwk_cache_pool, headers, &sni_hostname).await?;
let conn_info = get_conn_info(headers, sni_hostname)?;
// Determine the output options. Default behaviour is 'false'. Anything that is not
// strictly 'true' assumed to be false.
@@ -736,11 +697,6 @@ fn _pg_array_parse(
Ok((Value::Array(entries), 0))
}
#[derive(Serialize, Deserialize)]
pub struct NeonFields {
policies: Vec<Policy>,
}
#[cfg(test)]
mod tests {
use super::*;

143
vm-image-spec.yaml Normal file
View File

@@ -0,0 +1,143 @@
# Supplemental file for neondatabase/autoscaling's vm-builder, for producing the VM compute image.
# NOTE(review): field semantics (user / sysvInitAction / shell, files, build, merge) come from
# vm-builder's spec format — confirm against the vm-builder version pinned in CI (v0.19.0).
---
# Services started by the VM's init. `sysinit` entries run once at boot;
# `respawn` entries are restarted if they exit.
commands:
  # Apply the cgroup layout defined in cgconfig.conf (written below) before
  # anything else starts; the vm-monitor requires this cgroup to exist.
  - name: cgconfigparser
    user: root
    sysvInitAction: sysinit
    shell: "cgconfigparser -l /etc/cgconfig.conf -s 1664"
  # Connection pooler in front of the local postgres (config written below).
  - name: pgbouncer
    user: nobody
    sysvInitAction: respawn
    shell: "/usr/local/bin/pgbouncer /etc/pgbouncer.ini"
  # Prometheus exporter for postgres metrics, extended with the custom
  # queries file written below.
  - name: postgres-exporter
    user: nobody
    sysvInitAction: respawn
    shell: 'DATA_SOURCE_NAME="user=cloud_admin sslmode=disable dbname=postgres" /bin/postgres_exporter --extend.query-path /etc/postgres_exporter_queries.yml'
# Graceful shutdown: stop postgres (fast mode, 10s timeout) before the VM halts.
shutdownHook: |
  su -p postgres --session-command '/usr/local/bin/pg_ctl stop -D /var/db/postgres/compute/pgdata -m fast --wait -t 10'
# Files materialized during the image build; they are COPY'd into the final
# image by the `merge` Dockerfile fragment at the bottom of this spec.
files:
  - filename: pgbouncer.ini
    content: |
      [databases]
      *=host=localhost port=5432 auth_user=cloud_admin
      [pgbouncer]
      listen_port=6432
      listen_addr=0.0.0.0
      auth_type=scram-sha-256
      auth_user=cloud_admin
      auth_dbname=postgres
      client_tls_sslmode=disable
      server_tls_sslmode=disable
      pool_mode=transaction
      max_client_conn=10000
      default_pool_size=16
      max_prepared_statements=0
  - filename: cgconfig.conf
    content: |
      # Configuration for cgroups in VM compute nodes
      group neon-postgres {
          perm {
              admin {
                  uid = postgres;
              }
              task {
                  gid = users;
              }
          }
          memory {}
      }
  - filename: postgres_exporter_queries.yml
    content: |
      postgres_exporter_pg_database_size:
          query: "SELECT pg_database.datname, pg_database_size(pg_database.datname) as bytes, 42 as fourtytwo FROM pg_database"
          cache_seconds: 30
          metrics:
              - datname:
                  usage: "LABEL"
                  description: "Name of the database"
              - bytes:
                  usage: "GAUGE"
                  description: "Disk space used by the database"
              - fourtytwo:
                  usage: "GAUGE"
                  description: "fourtytwo"
# Dockerfile fragment defining auxiliary build stages; artifacts from these
# stages are pulled into the final image by `merge` below.
build: |
  # Build cgroup-tools
  #
  # At time of writing (2023-03-14), debian bullseye has a version of cgroup-tools (technically
  # libcgroup) that doesn't support cgroup v2 (version 0.41-11). Unfortunately, the vm-monitor
  # requires cgroup v2, so we'll build cgroup-tools ourselves.
  FROM debian:bullseye-slim as libcgroup-builder
  ENV LIBCGROUP_VERSION v2.0.3

  RUN set -exu \
      && apt update \
      && apt install --no-install-recommends -y \
             git \
             ca-certificates \
             automake \
             cmake \
             make \
             gcc \
             byacc \
             flex \
             libtool \
             libpam0g-dev \
      && git clone --depth 1 -b $LIBCGROUP_VERSION https://github.com/libcgroup/libcgroup \
      && INSTALL_DIR="/libcgroup-install" \
      && mkdir -p "$INSTALL_DIR/bin" "$INSTALL_DIR/include" \
      && cd libcgroup \
      # extracted from bootstrap.sh, with modified flags:
      && (test -d m4 || mkdir m4) \
      && autoreconf -fi \
      && rm -rf autom4te.cache \
      && CFLAGS="-O3" ./configure --prefix="$INSTALL_DIR" --sysconfdir=/etc --localstatedir=/var --enable-opaque-hierarchy="name=systemd" \
      # actually build the thing...
      && make install

  FROM quay.io/prometheuscommunity/postgres-exporter:v0.12.0 AS postgres-exporter

  # Build pgbouncer
  #
  FROM debian:bullseye-slim AS pgbouncer
  RUN set -e \
      && apt-get update \
      && apt-get install -y \
             curl \
             build-essential \
             pkg-config \
             libevent-dev \
             libssl-dev

  ENV PGBOUNCER_VERSION 1.21.0
  ENV PGBOUNCER_GITPATH 1_21_0
  RUN set -e \
      && curl -sfSL https://github.com/pgbouncer/pgbouncer/releases/download/pgbouncer_${PGBOUNCER_GITPATH}/pgbouncer-${PGBOUNCER_VERSION}.tar.gz -o pgbouncer-${PGBOUNCER_VERSION}.tar.gz \
      && tar xzvf pgbouncer-${PGBOUNCER_VERSION}.tar.gz \
      && cd pgbouncer-${PGBOUNCER_VERSION} \
      && LDFLAGS=-static ./configure --prefix=/usr/local/pgbouncer --without-openssl \
      && make -j $(nproc) \
      && make install
# Dockerfile fragment appended to the final image build: raises file-descriptor
# limits and installs the files/binaries produced above.
merge: |
  # tweak nofile limits
  RUN set -e \
      && echo 'fs.file-max = 1048576' >>/etc/sysctl.conf \
      && test ! -e /etc/security || ( \
             echo '*    -  nofile 1048576' >>/etc/security/limits.conf \
          && echo 'root -  nofile 1048576' >>/etc/security/limits.conf \
         )

  COPY cgconfig.conf /etc/cgconfig.conf
  COPY pgbouncer.ini /etc/pgbouncer.ini
  COPY postgres_exporter_queries.yml /etc/postgres_exporter_queries.yml

  RUN set -e \
      && chown postgres:postgres /etc/pgbouncer.ini \
      && chmod 0644 /etc/pgbouncer.ini \
      && chmod 0644 /etc/cgconfig.conf \
      && chmod 0644 /etc/postgres_exporter_queries.yml

  COPY --from=libcgroup-builder /libcgroup-install/bin/*  /usr/bin/
  COPY --from=libcgroup-builder /libcgroup-install/lib/*  /usr/lib/
  COPY --from=libcgroup-builder /libcgroup-install/sbin/* /usr/sbin/
  COPY --from=postgres-exporter /bin/postgres_exporter /bin/postgres_exporter
  COPY --from=pgbouncer         /usr/local/pgbouncer/bin/pgbouncer /usr/local/bin/pgbouncer

View File

@@ -54,7 +54,7 @@ ring = { version = "0.16", features = ["std"] }
rustls = { version = "0.21", features = ["dangerous_configuration"] }
scopeguard = { version = "1" }
serde = { version = "1", features = ["alloc", "derive"] }
serde_json = { version = "1", features = ["preserve_order", "raw_value"] }
serde_json = { version = "1", features = ["raw_value"] }
smallvec = { version = "1", default-features = false, features = ["write"] }
time = { version = "0.3", features = ["local-offset", "macros", "serde-well-known"] }
tokio = { version = "1", features = ["fs", "io-std", "io-util", "macros", "net", "process", "rt-multi-thread", "signal", "test-util"] }