Compare commits

..

2 Commits

Author SHA1 Message Date
BodoBolero
be2ad49a62 move benchmarks to another repo 2025-08-12 19:03:20 +02:00
Ruslan Talpa
d96cea1917 [proxy] handle options request in rest broker (cors headers) (#12744)
## Problem
rest broker needs to respond with the correct cors headers for the api
to be usable from other domains

## Summary of changes
added a code path in rest broker to handle the OPTIONS requests

---------

Co-authored-by: Ruslan Talpa <ruslan.talpa@databricks.com>
2025-07-31 13:05:09 +00:00
5 changed files with 168 additions and 126 deletions

View File

@@ -300,7 +300,9 @@ jobs:
benchmarks:
# `!failure() && !cancelled()` is required because the workflow depends on the job that can be skipped: `deploy` in PRs
if: github.ref_name == 'main' || (contains(github.event.pull_request.labels.*.name, 'run-benchmarks') && !failure() && !cancelled())
# if: github.ref_name == 'main' || (contains(github.event.pull_request.labels.*.name, 'run-benchmarks') && !failure() && !cancelled())
# moved to another repo
if: false
needs: [ check-permissions, build-build-tools-image, get-benchmarks-durations, deploy ]
permissions:
id-token: write # aws-actions/configure-aws-credentials

View File

@@ -407,8 +407,8 @@ fn get_database_stats(cli: &mut Client) -> anyhow::Result<(f64, i64)> {
// like `postgres_exporter` use it to query Postgres statistics.
// Use explicit 8 bytes type casts to match Rust types.
let stats = cli.query_one(
"SELECT COALESCE(pg_catalog.sum(active_time), 0.0)::pg_catalog.float8 AS total_active_time,
COALESCE(pg_catalog.sum(sessions), 0)::pg_catalog.int8 AS total_sessions
"SELECT pg_catalog.coalesce(pg_catalog.sum(active_time), 0.0)::pg_catalog.float8 AS total_active_time,
pg_catalog.coalesce(pg_catalog.sum(sessions), 0)::pg_catalog.bigint AS total_sessions
FROM pg_catalog.pg_stat_database
WHERE datname NOT IN (
'postgres',

View File

@@ -241,7 +241,7 @@ impl ComputeControlPlane {
drop_subscriptions_before_start,
grpc,
reconfigure_concurrency: 1,
features: vec![ComputeFeature::ActivityMonitorExperimental],
features: vec![],
cluster: None,
compute_ctl_config: compute_ctl_config.clone(),
privileged_role_name: privileged_role_name.clone(),
@@ -263,7 +263,7 @@ impl ComputeControlPlane {
skip_pg_catalog_updates,
drop_subscriptions_before_start,
reconfigure_concurrency: 1,
features: vec![ComputeFeature::ActivityMonitorExperimental],
features: vec![],
cluster: None,
compute_ctl_config,
privileged_role_name,

View File

@@ -5,12 +5,17 @@ use std::sync::Arc;
use bytes::Bytes;
use http::Method;
use http::header::{AUTHORIZATION, CONTENT_TYPE, HOST};
use http::header::{
ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_METHODS, ACCESS_CONTROL_ALLOW_ORIGIN,
ACCESS_CONTROL_EXPOSE_HEADERS, ACCESS_CONTROL_MAX_AGE, ACCESS_CONTROL_REQUEST_HEADERS, ALLOW,
AUTHORIZATION, CONTENT_TYPE, HOST, ORIGIN,
};
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Full};
use http_body_util::{BodyExt, Empty, Full};
use http_utils::error::ApiError;
use hyper::body::Incoming;
use hyper::http::{HeaderName, HeaderValue};
use hyper::http::response::Builder;
use hyper::http::{HeaderMap, HeaderName, HeaderValue};
use hyper::{Request, Response, StatusCode};
use indexmap::IndexMap;
use moka::sync::Cache;
@@ -67,6 +72,15 @@ use crate::util::deserialize_json_string;
static EMPTY_JSON_SCHEMA: &str = r#"{"schemas":[]}"#;
const INTROSPECTION_SQL: &str = POSTGRESQL_INTROSPECTION_SQL;
const HEADER_VALUE_ALLOW_ALL_ORIGINS: HeaderValue = HeaderValue::from_static("*");
// CORS headers values
const ACCESS_CONTROL_ALLOW_METHODS_VALUE: HeaderValue =
HeaderValue::from_static("GET, POST, PATCH, PUT, DELETE, OPTIONS");
const ACCESS_CONTROL_MAX_AGE_VALUE: HeaderValue = HeaderValue::from_static("86400");
const ACCESS_CONTROL_EXPOSE_HEADERS_VALUE: HeaderValue = HeaderValue::from_static(
"Content-Encoding, Content-Location, Content-Range, Content-Type, Date, Location, Server, Transfer-Encoding, Range-Unit",
);
const ACCESS_CONTROL_ALLOW_HEADERS_VALUE: HeaderValue = HeaderValue::from_static("Authorization");
// A wrapper around the DbSchema that allows for self-referencing
#[self_referencing]
@@ -137,6 +151,8 @@ pub struct ApiConfig {
pub role_claim_key: String,
#[serde(default, deserialize_with = "deserialize_comma_separated_option")]
pub db_extra_search_path: Option<Vec<String>>,
#[serde(default, deserialize_with = "deserialize_comma_separated_option")]
pub server_cors_allowed_origins: Option<Vec<String>>,
}
// The DbSchemaCache is a cache of the ApiConfig and DbSchemaOwned for each endpoint
@@ -165,7 +181,13 @@ impl DbSchemaCache {
}
}
pub async fn get_cached_or_remote(
pub fn get_cached(
&self,
endpoint_id: &EndpointCacheKey,
) -> Option<Arc<(ApiConfig, DbSchemaOwned)>> {
count_cache_outcome(CacheKind::Schema, self.0.get(endpoint_id))
}
pub async fn get_remote(
&self,
endpoint_id: &EndpointCacheKey,
auth_header: &HeaderValue,
@@ -174,47 +196,42 @@ impl DbSchemaCache {
ctx: &RequestContext,
config: &'static ProxyConfig,
) -> Result<Arc<(ApiConfig, DbSchemaOwned)>, RestError> {
let cache_result = count_cache_outcome(CacheKind::Schema, self.0.get(endpoint_id));
match cache_result {
Some(v) => Ok(v),
None => {
info!("db_schema cache miss for endpoint: {:?}", endpoint_id);
let remote_value = self
.get_remote(auth_header, connection_string, client, ctx, config)
.await;
let (api_config, schema_owned) = match remote_value {
Ok((api_config, schema_owned)) => (api_config, schema_owned),
Err(e @ RestError::SchemaTooLarge) => {
// for the case where the schema is too large, we cache an empty dummy value
// all the other requests will fail without triggering the introspection query
let schema_owned = serde_json::from_str::<DbSchemaOwned>(EMPTY_JSON_SCHEMA)
.map_err(|e| JsonDeserialize { source: e })?;
info!("db_schema cache miss for endpoint: {:?}", endpoint_id);
let remote_value = self
.internal_get_remote(auth_header, connection_string, client, ctx, config)
.await;
let (api_config, schema_owned) = match remote_value {
Ok((api_config, schema_owned)) => (api_config, schema_owned),
Err(e @ RestError::SchemaTooLarge) => {
// for the case where the schema is too large, we cache an empty dummy value
// all the other requests will fail without triggering the introspection query
let schema_owned = serde_json::from_str::<DbSchemaOwned>(EMPTY_JSON_SCHEMA)
.map_err(|e| JsonDeserialize { source: e })?;
let api_config = ApiConfig {
db_schemas: vec![],
db_anon_role: None,
db_max_rows: None,
db_allowed_select_functions: vec![],
role_claim_key: String::new(),
db_extra_search_path: None,
};
let value = Arc::new((api_config, schema_owned));
count_cache_insert(CacheKind::Schema);
self.0.insert(endpoint_id.clone(), value);
return Err(e);
}
Err(e) => {
return Err(e);
}
let api_config = ApiConfig {
db_schemas: vec![],
db_anon_role: None,
db_max_rows: None,
db_allowed_select_functions: vec![],
role_claim_key: String::new(),
db_extra_search_path: None,
server_cors_allowed_origins: None,
};
let value = Arc::new((api_config, schema_owned));
count_cache_insert(CacheKind::Schema);
self.0.insert(endpoint_id.clone(), value.clone());
Ok(value)
self.0.insert(endpoint_id.clone(), value);
return Err(e);
}
}
Err(e) => {
return Err(e);
}
};
let value = Arc::new((api_config, schema_owned));
count_cache_insert(CacheKind::Schema);
self.0.insert(endpoint_id.clone(), value.clone());
Ok(value)
}
pub async fn get_remote(
async fn internal_get_remote(
&self,
auth_header: &HeaderValue,
connection_string: &str,
@@ -531,7 +548,7 @@ pub(crate) async fn handle(
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
let result = handle_inner(cancel, config, &ctx, request, backend).await;
let mut response = match result {
let response = match result {
Ok(r) => {
ctx.set_success();
@@ -640,9 +657,6 @@ pub(crate) async fn handle(
}
};
response
.headers_mut()
.insert("Access-Control-Allow-Origin", HeaderValue::from_static("*"));
Ok(response)
}
@@ -722,6 +736,37 @@ async fn handle_inner(
}
}
fn apply_common_cors_headers(
response: &mut Builder,
request_headers: &HeaderMap,
allowed_origins: Option<&Vec<String>>,
) {
let request_origin = request_headers
.get(ORIGIN)
.map(|v| v.to_str().unwrap_or(""));
let response_allow_origin = match (request_origin, allowed_origins) {
(Some(or), Some(allowed_origins)) => {
if allowed_origins.iter().any(|o| o == or) {
Some(HeaderValue::from_str(or).unwrap_or(HEADER_VALUE_ALLOW_ALL_ORIGINS))
} else {
None
}
}
(Some(_), None) => Some(HEADER_VALUE_ALLOW_ALL_ORIGINS),
_ => None,
};
if let Some(h) = response.headers_mut() {
h.insert(
ACCESS_CONTROL_EXPOSE_HEADERS,
ACCESS_CONTROL_EXPOSE_HEADERS_VALUE,
);
if let Some(origin) = response_allow_origin {
h.insert(ACCESS_CONTROL_ALLOW_ORIGIN, origin);
}
}
}
#[allow(clippy::too_many_arguments)]
async fn handle_rest_inner(
config: &'static ProxyConfig,
@@ -733,12 +778,6 @@ async fn handle_rest_inner(
jwt: String,
backend: Arc<PoolingBackend>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, RestError> {
// validate the jwt token
let jwt_parsed = backend
.authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
.await
.map_err(HttpConnError::from)?;
let db_schema_cache =
config
.rest_config
@@ -754,28 +793,83 @@ async fn handle_rest_inner(
message: "Failed to get endpoint cache key".to_string(),
}))?;
let mut client = backend.connect_to_local_proxy(ctx, conn_info).await?;
let (parts, originial_body) = request.into_parts();
// try and get the cached entry for this endpoint
// it contains the api config and the introspected db schema
let cached_entry = db_schema_cache.get_cached(&endpoint_cache_key);
let allowed_origins = cached_entry
.as_ref()
.and_then(|arc| arc.0.server_cors_allowed_origins.as_ref());
let mut response = Response::builder();
apply_common_cors_headers(&mut response, &parts.headers, allowed_origins);
// handle the OPTIONS request
if parts.method == Method::OPTIONS {
let allowed_headers = parts
.headers
.get(ACCESS_CONTROL_REQUEST_HEADERS)
.and_then(|a| a.to_str().ok())
.filter(|v| !v.is_empty())
.map_or_else(
|| "Authorization".to_string(),
|v| format!("{v}, Authorization"),
);
return response
.status(StatusCode::OK)
.header(
ACCESS_CONTROL_ALLOW_METHODS,
ACCESS_CONTROL_ALLOW_METHODS_VALUE,
)
.header(ACCESS_CONTROL_MAX_AGE, ACCESS_CONTROL_MAX_AGE_VALUE)
.header(
ACCESS_CONTROL_ALLOW_HEADERS,
HeaderValue::from_str(&allowed_headers)
.unwrap_or(ACCESS_CONTROL_ALLOW_HEADERS_VALUE),
)
.header(ALLOW, ACCESS_CONTROL_ALLOW_METHODS_VALUE)
.body(Empty::new().map_err(|x| match x {}).boxed())
.map_err(|e| {
RestError::SubzeroCore(InternalError {
message: e.to_string(),
})
});
}
// validate the jwt token
let jwt_parsed = backend
.authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
.await
.map_err(HttpConnError::from)?;
let auth_header = parts
.headers
.get(AUTHORIZATION)
.ok_or(RestError::SubzeroCore(InternalError {
message: "Authorization header is required".to_string(),
}))?;
let mut client = backend.connect_to_local_proxy(ctx, conn_info).await?;
let entry = db_schema_cache
.get_cached_or_remote(
&endpoint_cache_key,
auth_header,
connection_string,
&mut client,
ctx,
config,
)
.await?;
let entry = match cached_entry {
Some(e) => e,
None => {
// if not cached, get the remote entry (will run the introspection query)
db_schema_cache
.get_remote(
&endpoint_cache_key,
auth_header,
connection_string,
&mut client,
ctx,
config,
)
.await?
}
};
let (api_config, db_schema_owned) = entry.as_ref();
let db_schema = db_schema_owned.borrow_schema();
let db_schemas = &api_config.db_schemas; // list of schemas available for the api
@@ -999,8 +1093,8 @@ async fn handle_rest_inner(
let _metrics = client.metrics(ctx); // FIXME: is everything in the context set correctly?
// send the request to the local proxy
let response = make_raw_local_proxy_request(&mut client, headers, req_body).await?;
let (parts, body) = response.into_parts();
let proxy_response = make_raw_local_proxy_request(&mut client, headers, req_body).await?;
let (response_parts, body) = proxy_response.into_parts();
let max_response = config.http_config.max_response_size_bytes;
let bytes = read_body_with_limit(body, max_response)
@@ -1009,7 +1103,7 @@ async fn handle_rest_inner(
// if the response status is greater than 399, then it is an error
// FIXME: check if there are other error codes or shapes of the response
if parts.status.as_u16() > 399 {
if response_parts.status.as_u16() > 399 {
// turn this postgres error from the json into PostgresError
let postgres_error = serde_json::from_slice(&bytes)
.map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?;
@@ -1175,7 +1269,7 @@ async fn handle_rest_inner(
.boxed();
// build the response
let mut response = Response::builder()
response = response
.status(StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))
.header(CONTENT_TYPE, http_content_type);

View File

@@ -1,6 +1,5 @@
from __future__ import annotations
import time
from typing import TYPE_CHECKING
from fixtures.metrics import parse_metrics
@@ -10,13 +9,13 @@ if TYPE_CHECKING:
from fixtures.neon_fixtures import NeonEnv
def test_compute_monitor_downtime_calculation(neon_simple_env: NeonEnv):
def test_compute_monitor(neon_simple_env: NeonEnv):
"""
Test that compute_ctl can detect Postgres going down (unresponsive) and
reconnect when it comes back online. Also check that the downtime metrics
are properly emitted.
"""
TEST_DB = "test_compute_monitor_downtime_calculation"
TEST_DB = "test_compute_monitor"
env = neon_simple_env
endpoint = env.endpoints.create_start("main")
@@ -69,56 +68,3 @@ def test_compute_monitor_downtime_calculation(neon_simple_env: NeonEnv):
# Just a sanity check that we log the downtime info
endpoint.log_contains("downtime_info")
def test_compute_monitor_activity(neon_simple_env: NeonEnv):
"""
Test compute monitor correctly detects user activity inside Postgres
and updates last_active timestamp in the /status response.
"""
TEST_DB = "test_compute_monitor_activity_db"
env = neon_simple_env
endpoint = env.endpoints.create_start("main")
with endpoint.cursor() as cursor:
# Create a new database because `postgres` DB is excluded
# from activity monitoring.
cursor.execute(f"CREATE DATABASE {TEST_DB}")
client = endpoint.http_client()
prev_last_active = None
def check_last_active():
nonlocal prev_last_active
with endpoint.cursor(dbname=TEST_DB) as cursor:
# Execute some dummy query to generate 'activity'.
cursor.execute("SELECT * FROM generate_series(1, 10000)")
status = client.status()
assert status["last_active"] is not None
prev_last_active = status["last_active"]
wait_until(check_last_active)
assert prev_last_active is not None
# Sleep for everything to settle down. It's not strictly necessary,
# but should still remove any potential noise and/or prevent test from passing
# even if compute monitor is not working.
time.sleep(3)
with endpoint.cursor(dbname=TEST_DB) as cursor:
cursor.execute("SELECT * FROM generate_series(1, 10000)")
def check_last_active_updated():
nonlocal prev_last_active
status = client.status()
assert status["last_active"] is not None
assert status["last_active"] != prev_last_active
assert status["last_active"] > prev_last_active
wait_until(check_last_active_updated)