mirror of
https://github.com/neondatabase/neon.git
synced 2025-12-22 21:59:59 +00:00
[proxy] handle options request in rest broker (cors headers) (#12744)
## Problem rest broker needs to respond with the correct cors headers for the api to be usable from other domains ## Summary of changes added a code path in rest broker to handle the OPTIONS requests --------- Co-authored-by: Ruslan Talpa <ruslan.talpa@databricks.com>
This commit is contained in:
@@ -5,12 +5,17 @@ use std::sync::Arc;
|
||||
|
||||
use bytes::Bytes;
|
||||
use http::Method;
|
||||
use http::header::{AUTHORIZATION, CONTENT_TYPE, HOST};
|
||||
use http::header::{
|
||||
ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_METHODS, ACCESS_CONTROL_ALLOW_ORIGIN,
|
||||
ACCESS_CONTROL_EXPOSE_HEADERS, ACCESS_CONTROL_MAX_AGE, ACCESS_CONTROL_REQUEST_HEADERS, ALLOW,
|
||||
AUTHORIZATION, CONTENT_TYPE, HOST, ORIGIN,
|
||||
};
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::{BodyExt, Full};
|
||||
use http_body_util::{BodyExt, Empty, Full};
|
||||
use http_utils::error::ApiError;
|
||||
use hyper::body::Incoming;
|
||||
use hyper::http::{HeaderName, HeaderValue};
|
||||
use hyper::http::response::Builder;
|
||||
use hyper::http::{HeaderMap, HeaderName, HeaderValue};
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
use indexmap::IndexMap;
|
||||
use moka::sync::Cache;
|
||||
@@ -67,6 +72,15 @@ use crate::util::deserialize_json_string;
|
||||
|
||||
static EMPTY_JSON_SCHEMA: &str = r#"{"schemas":[]}"#;
|
||||
const INTROSPECTION_SQL: &str = POSTGRESQL_INTROSPECTION_SQL;
|
||||
const HEADER_VALUE_ALLOW_ALL_ORIGINS: HeaderValue = HeaderValue::from_static("*");
|
||||
// CORS headers values
|
||||
const ACCESS_CONTROL_ALLOW_METHODS_VALUE: HeaderValue =
|
||||
HeaderValue::from_static("GET, POST, PATCH, PUT, DELETE, OPTIONS");
|
||||
const ACCESS_CONTROL_MAX_AGE_VALUE: HeaderValue = HeaderValue::from_static("86400");
|
||||
const ACCESS_CONTROL_EXPOSE_HEADERS_VALUE: HeaderValue = HeaderValue::from_static(
|
||||
"Content-Encoding, Content-Location, Content-Range, Content-Type, Date, Location, Server, Transfer-Encoding, Range-Unit",
|
||||
);
|
||||
const ACCESS_CONTROL_ALLOW_HEADERS_VALUE: HeaderValue = HeaderValue::from_static("Authorization");
|
||||
|
||||
// A wrapper around the DbSchema that allows for self-referencing
|
||||
#[self_referencing]
|
||||
@@ -137,6 +151,8 @@ pub struct ApiConfig {
|
||||
pub role_claim_key: String,
|
||||
#[serde(default, deserialize_with = "deserialize_comma_separated_option")]
|
||||
pub db_extra_search_path: Option<Vec<String>>,
|
||||
#[serde(default, deserialize_with = "deserialize_comma_separated_option")]
|
||||
pub server_cors_allowed_origins: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
// The DbSchemaCache is a cache of the ApiConfig and DbSchemaOwned for each endpoint
|
||||
@@ -165,7 +181,13 @@ impl DbSchemaCache {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_cached_or_remote(
|
||||
pub fn get_cached(
|
||||
&self,
|
||||
endpoint_id: &EndpointCacheKey,
|
||||
) -> Option<Arc<(ApiConfig, DbSchemaOwned)>> {
|
||||
count_cache_outcome(CacheKind::Schema, self.0.get(endpoint_id))
|
||||
}
|
||||
pub async fn get_remote(
|
||||
&self,
|
||||
endpoint_id: &EndpointCacheKey,
|
||||
auth_header: &HeaderValue,
|
||||
@@ -174,47 +196,42 @@ impl DbSchemaCache {
|
||||
ctx: &RequestContext,
|
||||
config: &'static ProxyConfig,
|
||||
) -> Result<Arc<(ApiConfig, DbSchemaOwned)>, RestError> {
|
||||
let cache_result = count_cache_outcome(CacheKind::Schema, self.0.get(endpoint_id));
|
||||
match cache_result {
|
||||
Some(v) => Ok(v),
|
||||
None => {
|
||||
info!("db_schema cache miss for endpoint: {:?}", endpoint_id);
|
||||
let remote_value = self
|
||||
.get_remote(auth_header, connection_string, client, ctx, config)
|
||||
.await;
|
||||
let (api_config, schema_owned) = match remote_value {
|
||||
Ok((api_config, schema_owned)) => (api_config, schema_owned),
|
||||
Err(e @ RestError::SchemaTooLarge) => {
|
||||
// for the case where the schema is too large, we cache an empty dummy value
|
||||
// all the other requests will fail without triggering the introspection query
|
||||
let schema_owned = serde_json::from_str::<DbSchemaOwned>(EMPTY_JSON_SCHEMA)
|
||||
.map_err(|e| JsonDeserialize { source: e })?;
|
||||
info!("db_schema cache miss for endpoint: {:?}", endpoint_id);
|
||||
let remote_value = self
|
||||
.internal_get_remote(auth_header, connection_string, client, ctx, config)
|
||||
.await;
|
||||
let (api_config, schema_owned) = match remote_value {
|
||||
Ok((api_config, schema_owned)) => (api_config, schema_owned),
|
||||
Err(e @ RestError::SchemaTooLarge) => {
|
||||
// for the case where the schema is too large, we cache an empty dummy value
|
||||
// all the other requests will fail without triggering the introspection query
|
||||
let schema_owned = serde_json::from_str::<DbSchemaOwned>(EMPTY_JSON_SCHEMA)
|
||||
.map_err(|e| JsonDeserialize { source: e })?;
|
||||
|
||||
let api_config = ApiConfig {
|
||||
db_schemas: vec![],
|
||||
db_anon_role: None,
|
||||
db_max_rows: None,
|
||||
db_allowed_select_functions: vec![],
|
||||
role_claim_key: String::new(),
|
||||
db_extra_search_path: None,
|
||||
};
|
||||
let value = Arc::new((api_config, schema_owned));
|
||||
count_cache_insert(CacheKind::Schema);
|
||||
self.0.insert(endpoint_id.clone(), value);
|
||||
return Err(e);
|
||||
}
|
||||
Err(e) => {
|
||||
return Err(e);
|
||||
}
|
||||
let api_config = ApiConfig {
|
||||
db_schemas: vec![],
|
||||
db_anon_role: None,
|
||||
db_max_rows: None,
|
||||
db_allowed_select_functions: vec![],
|
||||
role_claim_key: String::new(),
|
||||
db_extra_search_path: None,
|
||||
server_cors_allowed_origins: None,
|
||||
};
|
||||
let value = Arc::new((api_config, schema_owned));
|
||||
count_cache_insert(CacheKind::Schema);
|
||||
self.0.insert(endpoint_id.clone(), value.clone());
|
||||
Ok(value)
|
||||
self.0.insert(endpoint_id.clone(), value);
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
let value = Arc::new((api_config, schema_owned));
|
||||
count_cache_insert(CacheKind::Schema);
|
||||
self.0.insert(endpoint_id.clone(), value.clone());
|
||||
Ok(value)
|
||||
}
|
||||
pub async fn get_remote(
|
||||
async fn internal_get_remote(
|
||||
&self,
|
||||
auth_header: &HeaderValue,
|
||||
connection_string: &str,
|
||||
@@ -531,7 +548,7 @@ pub(crate) async fn handle(
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
|
||||
let result = handle_inner(cancel, config, &ctx, request, backend).await;
|
||||
|
||||
let mut response = match result {
|
||||
let response = match result {
|
||||
Ok(r) => {
|
||||
ctx.set_success();
|
||||
|
||||
@@ -640,9 +657,6 @@ pub(crate) async fn handle(
|
||||
}
|
||||
};
|
||||
|
||||
response
|
||||
.headers_mut()
|
||||
.insert("Access-Control-Allow-Origin", HeaderValue::from_static("*"));
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
@@ -722,6 +736,37 @@ async fn handle_inner(
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_common_cors_headers(
|
||||
response: &mut Builder,
|
||||
request_headers: &HeaderMap,
|
||||
allowed_origins: Option<&Vec<String>>,
|
||||
) {
|
||||
let request_origin = request_headers
|
||||
.get(ORIGIN)
|
||||
.map(|v| v.to_str().unwrap_or(""));
|
||||
|
||||
let response_allow_origin = match (request_origin, allowed_origins) {
|
||||
(Some(or), Some(allowed_origins)) => {
|
||||
if allowed_origins.iter().any(|o| o == or) {
|
||||
Some(HeaderValue::from_str(or).unwrap_or(HEADER_VALUE_ALLOW_ALL_ORIGINS))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
(Some(_), None) => Some(HEADER_VALUE_ALLOW_ALL_ORIGINS),
|
||||
_ => None,
|
||||
};
|
||||
if let Some(h) = response.headers_mut() {
|
||||
h.insert(
|
||||
ACCESS_CONTROL_EXPOSE_HEADERS,
|
||||
ACCESS_CONTROL_EXPOSE_HEADERS_VALUE,
|
||||
);
|
||||
if let Some(origin) = response_allow_origin {
|
||||
h.insert(ACCESS_CONTROL_ALLOW_ORIGIN, origin);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn handle_rest_inner(
|
||||
config: &'static ProxyConfig,
|
||||
@@ -733,12 +778,6 @@ async fn handle_rest_inner(
|
||||
jwt: String,
|
||||
backend: Arc<PoolingBackend>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, RestError> {
|
||||
// validate the jwt token
|
||||
let jwt_parsed = backend
|
||||
.authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
|
||||
.await
|
||||
.map_err(HttpConnError::from)?;
|
||||
|
||||
let db_schema_cache =
|
||||
config
|
||||
.rest_config
|
||||
@@ -754,28 +793,83 @@ async fn handle_rest_inner(
|
||||
message: "Failed to get endpoint cache key".to_string(),
|
||||
}))?;
|
||||
|
||||
let mut client = backend.connect_to_local_proxy(ctx, conn_info).await?;
|
||||
|
||||
let (parts, originial_body) = request.into_parts();
|
||||
|
||||
// try and get the cached entry for this endpoint
|
||||
// it contains the api config and the introspected db schema
|
||||
let cached_entry = db_schema_cache.get_cached(&endpoint_cache_key);
|
||||
|
||||
let allowed_origins = cached_entry
|
||||
.as_ref()
|
||||
.and_then(|arc| arc.0.server_cors_allowed_origins.as_ref());
|
||||
|
||||
let mut response = Response::builder();
|
||||
apply_common_cors_headers(&mut response, &parts.headers, allowed_origins);
|
||||
|
||||
// handle the OPTIONS request
|
||||
if parts.method == Method::OPTIONS {
|
||||
let allowed_headers = parts
|
||||
.headers
|
||||
.get(ACCESS_CONTROL_REQUEST_HEADERS)
|
||||
.and_then(|a| a.to_str().ok())
|
||||
.filter(|v| !v.is_empty())
|
||||
.map_or_else(
|
||||
|| "Authorization".to_string(),
|
||||
|v| format!("{v}, Authorization"),
|
||||
);
|
||||
return response
|
||||
.status(StatusCode::OK)
|
||||
.header(
|
||||
ACCESS_CONTROL_ALLOW_METHODS,
|
||||
ACCESS_CONTROL_ALLOW_METHODS_VALUE,
|
||||
)
|
||||
.header(ACCESS_CONTROL_MAX_AGE, ACCESS_CONTROL_MAX_AGE_VALUE)
|
||||
.header(
|
||||
ACCESS_CONTROL_ALLOW_HEADERS,
|
||||
HeaderValue::from_str(&allowed_headers)
|
||||
.unwrap_or(ACCESS_CONTROL_ALLOW_HEADERS_VALUE),
|
||||
)
|
||||
.header(ALLOW, ACCESS_CONTROL_ALLOW_METHODS_VALUE)
|
||||
.body(Empty::new().map_err(|x| match x {}).boxed())
|
||||
.map_err(|e| {
|
||||
RestError::SubzeroCore(InternalError {
|
||||
message: e.to_string(),
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
// validate the jwt token
|
||||
let jwt_parsed = backend
|
||||
.authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
|
||||
.await
|
||||
.map_err(HttpConnError::from)?;
|
||||
|
||||
let auth_header = parts
|
||||
.headers
|
||||
.get(AUTHORIZATION)
|
||||
.ok_or(RestError::SubzeroCore(InternalError {
|
||||
message: "Authorization header is required".to_string(),
|
||||
}))?;
|
||||
let mut client = backend.connect_to_local_proxy(ctx, conn_info).await?;
|
||||
|
||||
let entry = db_schema_cache
|
||||
.get_cached_or_remote(
|
||||
&endpoint_cache_key,
|
||||
auth_header,
|
||||
connection_string,
|
||||
&mut client,
|
||||
ctx,
|
||||
config,
|
||||
)
|
||||
.await?;
|
||||
let entry = match cached_entry {
|
||||
Some(e) => e,
|
||||
None => {
|
||||
// if not cached, get the remote entry (will run the introspection query)
|
||||
db_schema_cache
|
||||
.get_remote(
|
||||
&endpoint_cache_key,
|
||||
auth_header,
|
||||
connection_string,
|
||||
&mut client,
|
||||
ctx,
|
||||
config,
|
||||
)
|
||||
.await?
|
||||
}
|
||||
};
|
||||
let (api_config, db_schema_owned) = entry.as_ref();
|
||||
|
||||
let db_schema = db_schema_owned.borrow_schema();
|
||||
|
||||
let db_schemas = &api_config.db_schemas; // list of schemas available for the api
|
||||
@@ -999,8 +1093,8 @@ async fn handle_rest_inner(
|
||||
let _metrics = client.metrics(ctx); // FIXME: is everything in the context set correctly?
|
||||
|
||||
// send the request to the local proxy
|
||||
let response = make_raw_local_proxy_request(&mut client, headers, req_body).await?;
|
||||
let (parts, body) = response.into_parts();
|
||||
let proxy_response = make_raw_local_proxy_request(&mut client, headers, req_body).await?;
|
||||
let (response_parts, body) = proxy_response.into_parts();
|
||||
|
||||
let max_response = config.http_config.max_response_size_bytes;
|
||||
let bytes = read_body_with_limit(body, max_response)
|
||||
@@ -1009,7 +1103,7 @@ async fn handle_rest_inner(
|
||||
|
||||
// if the response status is greater than 399, then it is an error
|
||||
// FIXME: check if there are other error codes or shapes of the response
|
||||
if parts.status.as_u16() > 399 {
|
||||
if response_parts.status.as_u16() > 399 {
|
||||
// turn this postgres error from the json into PostgresError
|
||||
let postgres_error = serde_json::from_slice(&bytes)
|
||||
.map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?;
|
||||
@@ -1175,7 +1269,7 @@ async fn handle_rest_inner(
|
||||
.boxed();
|
||||
|
||||
// build the response
|
||||
let mut response = Response::builder()
|
||||
response = response
|
||||
.status(StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))
|
||||
.header(CONTENT_TYPE, http_content_type);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user