mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-16 12:40:36 +00:00
Compare commits
2 Commits
alexk/fix-
...
bodobolero
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
be2ad49a62 | ||
|
|
d96cea1917 |
4
.github/workflows/build_and_test.yml
vendored
4
.github/workflows/build_and_test.yml
vendored
@@ -300,7 +300,9 @@ jobs:
|
||||
|
||||
benchmarks:
|
||||
# `!failure() && !cancelled()` is required because the workflow depends on the job that can be skipped: `deploy` in PRs
|
||||
if: github.ref_name == 'main' || (contains(github.event.pull_request.labels.*.name, 'run-benchmarks') && !failure() && !cancelled())
|
||||
# if: github.ref_name == 'main' || (contains(github.event.pull_request.labels.*.name, 'run-benchmarks') && !failure() && !cancelled())
|
||||
# moved to another repo
|
||||
if: false
|
||||
needs: [ check-permissions, build-build-tools-image, get-benchmarks-durations, deploy ]
|
||||
permissions:
|
||||
id-token: write # aws-actions/configure-aws-credentials
|
||||
|
||||
@@ -407,8 +407,8 @@ fn get_database_stats(cli: &mut Client) -> anyhow::Result<(f64, i64)> {
|
||||
// like `postgres_exporter` use it to query Postgres statistics.
|
||||
// Use explicit 8 bytes type casts to match Rust types.
|
||||
let stats = cli.query_one(
|
||||
"SELECT COALESCE(pg_catalog.sum(active_time), 0.0)::pg_catalog.float8 AS total_active_time,
|
||||
COALESCE(pg_catalog.sum(sessions), 0)::pg_catalog.int8 AS total_sessions
|
||||
"SELECT pg_catalog.coalesce(pg_catalog.sum(active_time), 0.0)::pg_catalog.float8 AS total_active_time,
|
||||
pg_catalog.coalesce(pg_catalog.sum(sessions), 0)::pg_catalog.bigint AS total_sessions
|
||||
FROM pg_catalog.pg_stat_database
|
||||
WHERE datname NOT IN (
|
||||
'postgres',
|
||||
|
||||
@@ -241,7 +241,7 @@ impl ComputeControlPlane {
|
||||
drop_subscriptions_before_start,
|
||||
grpc,
|
||||
reconfigure_concurrency: 1,
|
||||
features: vec![ComputeFeature::ActivityMonitorExperimental],
|
||||
features: vec![],
|
||||
cluster: None,
|
||||
compute_ctl_config: compute_ctl_config.clone(),
|
||||
privileged_role_name: privileged_role_name.clone(),
|
||||
@@ -263,7 +263,7 @@ impl ComputeControlPlane {
|
||||
skip_pg_catalog_updates,
|
||||
drop_subscriptions_before_start,
|
||||
reconfigure_concurrency: 1,
|
||||
features: vec![ComputeFeature::ActivityMonitorExperimental],
|
||||
features: vec![],
|
||||
cluster: None,
|
||||
compute_ctl_config,
|
||||
privileged_role_name,
|
||||
|
||||
@@ -5,12 +5,17 @@ use std::sync::Arc;
|
||||
|
||||
use bytes::Bytes;
|
||||
use http::Method;
|
||||
use http::header::{AUTHORIZATION, CONTENT_TYPE, HOST};
|
||||
use http::header::{
|
||||
ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_METHODS, ACCESS_CONTROL_ALLOW_ORIGIN,
|
||||
ACCESS_CONTROL_EXPOSE_HEADERS, ACCESS_CONTROL_MAX_AGE, ACCESS_CONTROL_REQUEST_HEADERS, ALLOW,
|
||||
AUTHORIZATION, CONTENT_TYPE, HOST, ORIGIN,
|
||||
};
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::{BodyExt, Full};
|
||||
use http_body_util::{BodyExt, Empty, Full};
|
||||
use http_utils::error::ApiError;
|
||||
use hyper::body::Incoming;
|
||||
use hyper::http::{HeaderName, HeaderValue};
|
||||
use hyper::http::response::Builder;
|
||||
use hyper::http::{HeaderMap, HeaderName, HeaderValue};
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
use indexmap::IndexMap;
|
||||
use moka::sync::Cache;
|
||||
@@ -67,6 +72,15 @@ use crate::util::deserialize_json_string;
|
||||
|
||||
static EMPTY_JSON_SCHEMA: &str = r#"{"schemas":[]}"#;
|
||||
const INTROSPECTION_SQL: &str = POSTGRESQL_INTROSPECTION_SQL;
|
||||
const HEADER_VALUE_ALLOW_ALL_ORIGINS: HeaderValue = HeaderValue::from_static("*");
|
||||
// CORS headers values
|
||||
const ACCESS_CONTROL_ALLOW_METHODS_VALUE: HeaderValue =
|
||||
HeaderValue::from_static("GET, POST, PATCH, PUT, DELETE, OPTIONS");
|
||||
const ACCESS_CONTROL_MAX_AGE_VALUE: HeaderValue = HeaderValue::from_static("86400");
|
||||
const ACCESS_CONTROL_EXPOSE_HEADERS_VALUE: HeaderValue = HeaderValue::from_static(
|
||||
"Content-Encoding, Content-Location, Content-Range, Content-Type, Date, Location, Server, Transfer-Encoding, Range-Unit",
|
||||
);
|
||||
const ACCESS_CONTROL_ALLOW_HEADERS_VALUE: HeaderValue = HeaderValue::from_static("Authorization");
|
||||
|
||||
// A wrapper around the DbSchema that allows for self-referencing
|
||||
#[self_referencing]
|
||||
@@ -137,6 +151,8 @@ pub struct ApiConfig {
|
||||
pub role_claim_key: String,
|
||||
#[serde(default, deserialize_with = "deserialize_comma_separated_option")]
|
||||
pub db_extra_search_path: Option<Vec<String>>,
|
||||
#[serde(default, deserialize_with = "deserialize_comma_separated_option")]
|
||||
pub server_cors_allowed_origins: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
// The DbSchemaCache is a cache of the ApiConfig and DbSchemaOwned for each endpoint
|
||||
@@ -165,7 +181,13 @@ impl DbSchemaCache {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_cached_or_remote(
|
||||
pub fn get_cached(
|
||||
&self,
|
||||
endpoint_id: &EndpointCacheKey,
|
||||
) -> Option<Arc<(ApiConfig, DbSchemaOwned)>> {
|
||||
count_cache_outcome(CacheKind::Schema, self.0.get(endpoint_id))
|
||||
}
|
||||
pub async fn get_remote(
|
||||
&self,
|
||||
endpoint_id: &EndpointCacheKey,
|
||||
auth_header: &HeaderValue,
|
||||
@@ -174,47 +196,42 @@ impl DbSchemaCache {
|
||||
ctx: &RequestContext,
|
||||
config: &'static ProxyConfig,
|
||||
) -> Result<Arc<(ApiConfig, DbSchemaOwned)>, RestError> {
|
||||
let cache_result = count_cache_outcome(CacheKind::Schema, self.0.get(endpoint_id));
|
||||
match cache_result {
|
||||
Some(v) => Ok(v),
|
||||
None => {
|
||||
info!("db_schema cache miss for endpoint: {:?}", endpoint_id);
|
||||
let remote_value = self
|
||||
.get_remote(auth_header, connection_string, client, ctx, config)
|
||||
.await;
|
||||
let (api_config, schema_owned) = match remote_value {
|
||||
Ok((api_config, schema_owned)) => (api_config, schema_owned),
|
||||
Err(e @ RestError::SchemaTooLarge) => {
|
||||
// for the case where the schema is too large, we cache an empty dummy value
|
||||
// all the other requests will fail without triggering the introspection query
|
||||
let schema_owned = serde_json::from_str::<DbSchemaOwned>(EMPTY_JSON_SCHEMA)
|
||||
.map_err(|e| JsonDeserialize { source: e })?;
|
||||
info!("db_schema cache miss for endpoint: {:?}", endpoint_id);
|
||||
let remote_value = self
|
||||
.internal_get_remote(auth_header, connection_string, client, ctx, config)
|
||||
.await;
|
||||
let (api_config, schema_owned) = match remote_value {
|
||||
Ok((api_config, schema_owned)) => (api_config, schema_owned),
|
||||
Err(e @ RestError::SchemaTooLarge) => {
|
||||
// for the case where the schema is too large, we cache an empty dummy value
|
||||
// all the other requests will fail without triggering the introspection query
|
||||
let schema_owned = serde_json::from_str::<DbSchemaOwned>(EMPTY_JSON_SCHEMA)
|
||||
.map_err(|e| JsonDeserialize { source: e })?;
|
||||
|
||||
let api_config = ApiConfig {
|
||||
db_schemas: vec![],
|
||||
db_anon_role: None,
|
||||
db_max_rows: None,
|
||||
db_allowed_select_functions: vec![],
|
||||
role_claim_key: String::new(),
|
||||
db_extra_search_path: None,
|
||||
};
|
||||
let value = Arc::new((api_config, schema_owned));
|
||||
count_cache_insert(CacheKind::Schema);
|
||||
self.0.insert(endpoint_id.clone(), value);
|
||||
return Err(e);
|
||||
}
|
||||
Err(e) => {
|
||||
return Err(e);
|
||||
}
|
||||
let api_config = ApiConfig {
|
||||
db_schemas: vec![],
|
||||
db_anon_role: None,
|
||||
db_max_rows: None,
|
||||
db_allowed_select_functions: vec![],
|
||||
role_claim_key: String::new(),
|
||||
db_extra_search_path: None,
|
||||
server_cors_allowed_origins: None,
|
||||
};
|
||||
let value = Arc::new((api_config, schema_owned));
|
||||
count_cache_insert(CacheKind::Schema);
|
||||
self.0.insert(endpoint_id.clone(), value.clone());
|
||||
Ok(value)
|
||||
self.0.insert(endpoint_id.clone(), value);
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
let value = Arc::new((api_config, schema_owned));
|
||||
count_cache_insert(CacheKind::Schema);
|
||||
self.0.insert(endpoint_id.clone(), value.clone());
|
||||
Ok(value)
|
||||
}
|
||||
pub async fn get_remote(
|
||||
async fn internal_get_remote(
|
||||
&self,
|
||||
auth_header: &HeaderValue,
|
||||
connection_string: &str,
|
||||
@@ -531,7 +548,7 @@ pub(crate) async fn handle(
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
|
||||
let result = handle_inner(cancel, config, &ctx, request, backend).await;
|
||||
|
||||
let mut response = match result {
|
||||
let response = match result {
|
||||
Ok(r) => {
|
||||
ctx.set_success();
|
||||
|
||||
@@ -640,9 +657,6 @@ pub(crate) async fn handle(
|
||||
}
|
||||
};
|
||||
|
||||
response
|
||||
.headers_mut()
|
||||
.insert("Access-Control-Allow-Origin", HeaderValue::from_static("*"));
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
@@ -722,6 +736,37 @@ async fn handle_inner(
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_common_cors_headers(
|
||||
response: &mut Builder,
|
||||
request_headers: &HeaderMap,
|
||||
allowed_origins: Option<&Vec<String>>,
|
||||
) {
|
||||
let request_origin = request_headers
|
||||
.get(ORIGIN)
|
||||
.map(|v| v.to_str().unwrap_or(""));
|
||||
|
||||
let response_allow_origin = match (request_origin, allowed_origins) {
|
||||
(Some(or), Some(allowed_origins)) => {
|
||||
if allowed_origins.iter().any(|o| o == or) {
|
||||
Some(HeaderValue::from_str(or).unwrap_or(HEADER_VALUE_ALLOW_ALL_ORIGINS))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
(Some(_), None) => Some(HEADER_VALUE_ALLOW_ALL_ORIGINS),
|
||||
_ => None,
|
||||
};
|
||||
if let Some(h) = response.headers_mut() {
|
||||
h.insert(
|
||||
ACCESS_CONTROL_EXPOSE_HEADERS,
|
||||
ACCESS_CONTROL_EXPOSE_HEADERS_VALUE,
|
||||
);
|
||||
if let Some(origin) = response_allow_origin {
|
||||
h.insert(ACCESS_CONTROL_ALLOW_ORIGIN, origin);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn handle_rest_inner(
|
||||
config: &'static ProxyConfig,
|
||||
@@ -733,12 +778,6 @@ async fn handle_rest_inner(
|
||||
jwt: String,
|
||||
backend: Arc<PoolingBackend>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, RestError> {
|
||||
// validate the jwt token
|
||||
let jwt_parsed = backend
|
||||
.authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
|
||||
.await
|
||||
.map_err(HttpConnError::from)?;
|
||||
|
||||
let db_schema_cache =
|
||||
config
|
||||
.rest_config
|
||||
@@ -754,28 +793,83 @@ async fn handle_rest_inner(
|
||||
message: "Failed to get endpoint cache key".to_string(),
|
||||
}))?;
|
||||
|
||||
let mut client = backend.connect_to_local_proxy(ctx, conn_info).await?;
|
||||
|
||||
let (parts, originial_body) = request.into_parts();
|
||||
|
||||
// try and get the cached entry for this endpoint
|
||||
// it contains the api config and the introspected db schema
|
||||
let cached_entry = db_schema_cache.get_cached(&endpoint_cache_key);
|
||||
|
||||
let allowed_origins = cached_entry
|
||||
.as_ref()
|
||||
.and_then(|arc| arc.0.server_cors_allowed_origins.as_ref());
|
||||
|
||||
let mut response = Response::builder();
|
||||
apply_common_cors_headers(&mut response, &parts.headers, allowed_origins);
|
||||
|
||||
// handle the OPTIONS request
|
||||
if parts.method == Method::OPTIONS {
|
||||
let allowed_headers = parts
|
||||
.headers
|
||||
.get(ACCESS_CONTROL_REQUEST_HEADERS)
|
||||
.and_then(|a| a.to_str().ok())
|
||||
.filter(|v| !v.is_empty())
|
||||
.map_or_else(
|
||||
|| "Authorization".to_string(),
|
||||
|v| format!("{v}, Authorization"),
|
||||
);
|
||||
return response
|
||||
.status(StatusCode::OK)
|
||||
.header(
|
||||
ACCESS_CONTROL_ALLOW_METHODS,
|
||||
ACCESS_CONTROL_ALLOW_METHODS_VALUE,
|
||||
)
|
||||
.header(ACCESS_CONTROL_MAX_AGE, ACCESS_CONTROL_MAX_AGE_VALUE)
|
||||
.header(
|
||||
ACCESS_CONTROL_ALLOW_HEADERS,
|
||||
HeaderValue::from_str(&allowed_headers)
|
||||
.unwrap_or(ACCESS_CONTROL_ALLOW_HEADERS_VALUE),
|
||||
)
|
||||
.header(ALLOW, ACCESS_CONTROL_ALLOW_METHODS_VALUE)
|
||||
.body(Empty::new().map_err(|x| match x {}).boxed())
|
||||
.map_err(|e| {
|
||||
RestError::SubzeroCore(InternalError {
|
||||
message: e.to_string(),
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
// validate the jwt token
|
||||
let jwt_parsed = backend
|
||||
.authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
|
||||
.await
|
||||
.map_err(HttpConnError::from)?;
|
||||
|
||||
let auth_header = parts
|
||||
.headers
|
||||
.get(AUTHORIZATION)
|
||||
.ok_or(RestError::SubzeroCore(InternalError {
|
||||
message: "Authorization header is required".to_string(),
|
||||
}))?;
|
||||
let mut client = backend.connect_to_local_proxy(ctx, conn_info).await?;
|
||||
|
||||
let entry = db_schema_cache
|
||||
.get_cached_or_remote(
|
||||
&endpoint_cache_key,
|
||||
auth_header,
|
||||
connection_string,
|
||||
&mut client,
|
||||
ctx,
|
||||
config,
|
||||
)
|
||||
.await?;
|
||||
let entry = match cached_entry {
|
||||
Some(e) => e,
|
||||
None => {
|
||||
// if not cached, get the remote entry (will run the introspection query)
|
||||
db_schema_cache
|
||||
.get_remote(
|
||||
&endpoint_cache_key,
|
||||
auth_header,
|
||||
connection_string,
|
||||
&mut client,
|
||||
ctx,
|
||||
config,
|
||||
)
|
||||
.await?
|
||||
}
|
||||
};
|
||||
let (api_config, db_schema_owned) = entry.as_ref();
|
||||
|
||||
let db_schema = db_schema_owned.borrow_schema();
|
||||
|
||||
let db_schemas = &api_config.db_schemas; // list of schemas available for the api
|
||||
@@ -999,8 +1093,8 @@ async fn handle_rest_inner(
|
||||
let _metrics = client.metrics(ctx); // FIXME: is everything in the context set correctly?
|
||||
|
||||
// send the request to the local proxy
|
||||
let response = make_raw_local_proxy_request(&mut client, headers, req_body).await?;
|
||||
let (parts, body) = response.into_parts();
|
||||
let proxy_response = make_raw_local_proxy_request(&mut client, headers, req_body).await?;
|
||||
let (response_parts, body) = proxy_response.into_parts();
|
||||
|
||||
let max_response = config.http_config.max_response_size_bytes;
|
||||
let bytes = read_body_with_limit(body, max_response)
|
||||
@@ -1009,7 +1103,7 @@ async fn handle_rest_inner(
|
||||
|
||||
// if the response status is greater than 399, then it is an error
|
||||
// FIXME: check if there are other error codes or shapes of the response
|
||||
if parts.status.as_u16() > 399 {
|
||||
if response_parts.status.as_u16() > 399 {
|
||||
// turn this postgres error from the json into PostgresError
|
||||
let postgres_error = serde_json::from_slice(&bytes)
|
||||
.map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?;
|
||||
@@ -1175,7 +1269,7 @@ async fn handle_rest_inner(
|
||||
.boxed();
|
||||
|
||||
// build the response
|
||||
let mut response = Response::builder()
|
||||
response = response
|
||||
.status(StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))
|
||||
.header(CONTENT_TYPE, http_content_type);
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fixtures.metrics import parse_metrics
|
||||
@@ -10,13 +9,13 @@ if TYPE_CHECKING:
|
||||
from fixtures.neon_fixtures import NeonEnv
|
||||
|
||||
|
||||
def test_compute_monitor_downtime_calculation(neon_simple_env: NeonEnv):
|
||||
def test_compute_monitor(neon_simple_env: NeonEnv):
|
||||
"""
|
||||
Test that compute_ctl can detect Postgres going down (unresponsive) and
|
||||
reconnect when it comes back online. Also check that the downtime metrics
|
||||
are properly emitted.
|
||||
"""
|
||||
TEST_DB = "test_compute_monitor_downtime_calculation"
|
||||
TEST_DB = "test_compute_monitor"
|
||||
|
||||
env = neon_simple_env
|
||||
endpoint = env.endpoints.create_start("main")
|
||||
@@ -69,56 +68,3 @@ def test_compute_monitor_downtime_calculation(neon_simple_env: NeonEnv):
|
||||
|
||||
# Just a sanity check that we log the downtime info
|
||||
endpoint.log_contains("downtime_info")
|
||||
|
||||
|
||||
def test_compute_monitor_activity(neon_simple_env: NeonEnv):
|
||||
"""
|
||||
Test compute monitor correctly detects user activity inside Postgres
|
||||
and updates last_active timestamp in the /status response.
|
||||
"""
|
||||
TEST_DB = "test_compute_monitor_activity_db"
|
||||
|
||||
env = neon_simple_env
|
||||
endpoint = env.endpoints.create_start("main")
|
||||
|
||||
with endpoint.cursor() as cursor:
|
||||
# Create a new database because `postgres` DB is excluded
|
||||
# from activity monitoring.
|
||||
cursor.execute(f"CREATE DATABASE {TEST_DB}")
|
||||
|
||||
client = endpoint.http_client()
|
||||
|
||||
prev_last_active = None
|
||||
|
||||
def check_last_active():
|
||||
nonlocal prev_last_active
|
||||
|
||||
with endpoint.cursor(dbname=TEST_DB) as cursor:
|
||||
# Execute some dummy query to generate 'activity'.
|
||||
cursor.execute("SELECT * FROM generate_series(1, 10000)")
|
||||
|
||||
status = client.status()
|
||||
assert status["last_active"] is not None
|
||||
prev_last_active = status["last_active"]
|
||||
|
||||
wait_until(check_last_active)
|
||||
|
||||
assert prev_last_active is not None
|
||||
|
||||
# Sleep for everything to settle down. It's not strictly necessary,
|
||||
# but should still remove any potential noise and/or prevent test from passing
|
||||
# even if compute monitor is not working.
|
||||
time.sleep(3)
|
||||
|
||||
with endpoint.cursor(dbname=TEST_DB) as cursor:
|
||||
cursor.execute("SELECT * FROM generate_series(1, 10000)")
|
||||
|
||||
def check_last_active_updated():
|
||||
nonlocal prev_last_active
|
||||
|
||||
status = client.status()
|
||||
assert status["last_active"] is not None
|
||||
assert status["last_active"] != prev_last_active
|
||||
assert status["last_active"] > prev_last_active
|
||||
|
||||
wait_until(check_last_active_updated)
|
||||
|
||||
Reference in New Issue
Block a user