mirror of
https://github.com/neondatabase/neon.git
synced 2025-12-22 21:59:59 +00:00
[proxy] handle options request in rest broker (cors headers) (#12744)
## Problem rest broker needs to respond with the correct cors headers for the api to be usable from other domains ## Summary of changes added a code path in rest broker to handle the OPTIONS requests --------- Co-authored-by: Ruslan Talpa <ruslan.talpa@databricks.com>
This commit is contained in:
@@ -5,12 +5,17 @@ use std::sync::Arc;
|
|||||||
|
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use http::Method;
|
use http::Method;
|
||||||
use http::header::{AUTHORIZATION, CONTENT_TYPE, HOST};
|
use http::header::{
|
||||||
|
ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_METHODS, ACCESS_CONTROL_ALLOW_ORIGIN,
|
||||||
|
ACCESS_CONTROL_EXPOSE_HEADERS, ACCESS_CONTROL_MAX_AGE, ACCESS_CONTROL_REQUEST_HEADERS, ALLOW,
|
||||||
|
AUTHORIZATION, CONTENT_TYPE, HOST, ORIGIN,
|
||||||
|
};
|
||||||
use http_body_util::combinators::BoxBody;
|
use http_body_util::combinators::BoxBody;
|
||||||
use http_body_util::{BodyExt, Full};
|
use http_body_util::{BodyExt, Empty, Full};
|
||||||
use http_utils::error::ApiError;
|
use http_utils::error::ApiError;
|
||||||
use hyper::body::Incoming;
|
use hyper::body::Incoming;
|
||||||
use hyper::http::{HeaderName, HeaderValue};
|
use hyper::http::response::Builder;
|
||||||
|
use hyper::http::{HeaderMap, HeaderName, HeaderValue};
|
||||||
use hyper::{Request, Response, StatusCode};
|
use hyper::{Request, Response, StatusCode};
|
||||||
use indexmap::IndexMap;
|
use indexmap::IndexMap;
|
||||||
use moka::sync::Cache;
|
use moka::sync::Cache;
|
||||||
@@ -67,6 +72,15 @@ use crate::util::deserialize_json_string;
|
|||||||
|
|
||||||
static EMPTY_JSON_SCHEMA: &str = r#"{"schemas":[]}"#;
|
static EMPTY_JSON_SCHEMA: &str = r#"{"schemas":[]}"#;
|
||||||
const INTROSPECTION_SQL: &str = POSTGRESQL_INTROSPECTION_SQL;
|
const INTROSPECTION_SQL: &str = POSTGRESQL_INTROSPECTION_SQL;
|
||||||
|
const HEADER_VALUE_ALLOW_ALL_ORIGINS: HeaderValue = HeaderValue::from_static("*");
|
||||||
|
// CORS headers values
|
||||||
|
const ACCESS_CONTROL_ALLOW_METHODS_VALUE: HeaderValue =
|
||||||
|
HeaderValue::from_static("GET, POST, PATCH, PUT, DELETE, OPTIONS");
|
||||||
|
const ACCESS_CONTROL_MAX_AGE_VALUE: HeaderValue = HeaderValue::from_static("86400");
|
||||||
|
const ACCESS_CONTROL_EXPOSE_HEADERS_VALUE: HeaderValue = HeaderValue::from_static(
|
||||||
|
"Content-Encoding, Content-Location, Content-Range, Content-Type, Date, Location, Server, Transfer-Encoding, Range-Unit",
|
||||||
|
);
|
||||||
|
const ACCESS_CONTROL_ALLOW_HEADERS_VALUE: HeaderValue = HeaderValue::from_static("Authorization");
|
||||||
|
|
||||||
// A wrapper around the DbSchema that allows for self-referencing
|
// A wrapper around the DbSchema that allows for self-referencing
|
||||||
#[self_referencing]
|
#[self_referencing]
|
||||||
@@ -137,6 +151,8 @@ pub struct ApiConfig {
|
|||||||
pub role_claim_key: String,
|
pub role_claim_key: String,
|
||||||
#[serde(default, deserialize_with = "deserialize_comma_separated_option")]
|
#[serde(default, deserialize_with = "deserialize_comma_separated_option")]
|
||||||
pub db_extra_search_path: Option<Vec<String>>,
|
pub db_extra_search_path: Option<Vec<String>>,
|
||||||
|
#[serde(default, deserialize_with = "deserialize_comma_separated_option")]
|
||||||
|
pub server_cors_allowed_origins: Option<Vec<String>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
// The DbSchemaCache is a cache of the ApiConfig and DbSchemaOwned for each endpoint
|
// The DbSchemaCache is a cache of the ApiConfig and DbSchemaOwned for each endpoint
|
||||||
@@ -165,7 +181,13 @@ impl DbSchemaCache {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_cached_or_remote(
|
pub fn get_cached(
|
||||||
|
&self,
|
||||||
|
endpoint_id: &EndpointCacheKey,
|
||||||
|
) -> Option<Arc<(ApiConfig, DbSchemaOwned)>> {
|
||||||
|
count_cache_outcome(CacheKind::Schema, self.0.get(endpoint_id))
|
||||||
|
}
|
||||||
|
pub async fn get_remote(
|
||||||
&self,
|
&self,
|
||||||
endpoint_id: &EndpointCacheKey,
|
endpoint_id: &EndpointCacheKey,
|
||||||
auth_header: &HeaderValue,
|
auth_header: &HeaderValue,
|
||||||
@@ -174,47 +196,42 @@ impl DbSchemaCache {
|
|||||||
ctx: &RequestContext,
|
ctx: &RequestContext,
|
||||||
config: &'static ProxyConfig,
|
config: &'static ProxyConfig,
|
||||||
) -> Result<Arc<(ApiConfig, DbSchemaOwned)>, RestError> {
|
) -> Result<Arc<(ApiConfig, DbSchemaOwned)>, RestError> {
|
||||||
let cache_result = count_cache_outcome(CacheKind::Schema, self.0.get(endpoint_id));
|
info!("db_schema cache miss for endpoint: {:?}", endpoint_id);
|
||||||
match cache_result {
|
let remote_value = self
|
||||||
Some(v) => Ok(v),
|
.internal_get_remote(auth_header, connection_string, client, ctx, config)
|
||||||
None => {
|
.await;
|
||||||
info!("db_schema cache miss for endpoint: {:?}", endpoint_id);
|
let (api_config, schema_owned) = match remote_value {
|
||||||
let remote_value = self
|
Ok((api_config, schema_owned)) => (api_config, schema_owned),
|
||||||
.get_remote(auth_header, connection_string, client, ctx, config)
|
Err(e @ RestError::SchemaTooLarge) => {
|
||||||
.await;
|
// for the case where the schema is too large, we cache an empty dummy value
|
||||||
let (api_config, schema_owned) = match remote_value {
|
// all the other requests will fail without triggering the introspection query
|
||||||
Ok((api_config, schema_owned)) => (api_config, schema_owned),
|
let schema_owned = serde_json::from_str::<DbSchemaOwned>(EMPTY_JSON_SCHEMA)
|
||||||
Err(e @ RestError::SchemaTooLarge) => {
|
.map_err(|e| JsonDeserialize { source: e })?;
|
||||||
// for the case where the schema is too large, we cache an empty dummy value
|
|
||||||
// all the other requests will fail without triggering the introspection query
|
|
||||||
let schema_owned = serde_json::from_str::<DbSchemaOwned>(EMPTY_JSON_SCHEMA)
|
|
||||||
.map_err(|e| JsonDeserialize { source: e })?;
|
|
||||||
|
|
||||||
let api_config = ApiConfig {
|
let api_config = ApiConfig {
|
||||||
db_schemas: vec![],
|
db_schemas: vec![],
|
||||||
db_anon_role: None,
|
db_anon_role: None,
|
||||||
db_max_rows: None,
|
db_max_rows: None,
|
||||||
db_allowed_select_functions: vec![],
|
db_allowed_select_functions: vec![],
|
||||||
role_claim_key: String::new(),
|
role_claim_key: String::new(),
|
||||||
db_extra_search_path: None,
|
db_extra_search_path: None,
|
||||||
};
|
server_cors_allowed_origins: None,
|
||||||
let value = Arc::new((api_config, schema_owned));
|
|
||||||
count_cache_insert(CacheKind::Schema);
|
|
||||||
self.0.insert(endpoint_id.clone(), value);
|
|
||||||
return Err(e);
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
return Err(e);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
let value = Arc::new((api_config, schema_owned));
|
let value = Arc::new((api_config, schema_owned));
|
||||||
count_cache_insert(CacheKind::Schema);
|
count_cache_insert(CacheKind::Schema);
|
||||||
self.0.insert(endpoint_id.clone(), value.clone());
|
self.0.insert(endpoint_id.clone(), value);
|
||||||
Ok(value)
|
return Err(e);
|
||||||
}
|
}
|
||||||
}
|
Err(e) => {
|
||||||
|
return Err(e);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let value = Arc::new((api_config, schema_owned));
|
||||||
|
count_cache_insert(CacheKind::Schema);
|
||||||
|
self.0.insert(endpoint_id.clone(), value.clone());
|
||||||
|
Ok(value)
|
||||||
}
|
}
|
||||||
pub async fn get_remote(
|
async fn internal_get_remote(
|
||||||
&self,
|
&self,
|
||||||
auth_header: &HeaderValue,
|
auth_header: &HeaderValue,
|
||||||
connection_string: &str,
|
connection_string: &str,
|
||||||
@@ -531,7 +548,7 @@ pub(crate) async fn handle(
|
|||||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
|
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
|
||||||
let result = handle_inner(cancel, config, &ctx, request, backend).await;
|
let result = handle_inner(cancel, config, &ctx, request, backend).await;
|
||||||
|
|
||||||
let mut response = match result {
|
let response = match result {
|
||||||
Ok(r) => {
|
Ok(r) => {
|
||||||
ctx.set_success();
|
ctx.set_success();
|
||||||
|
|
||||||
@@ -640,9 +657,6 @@ pub(crate) async fn handle(
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
response
|
|
||||||
.headers_mut()
|
|
||||||
.insert("Access-Control-Allow-Origin", HeaderValue::from_static("*"));
|
|
||||||
Ok(response)
|
Ok(response)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -722,6 +736,37 @@ async fn handle_inner(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn apply_common_cors_headers(
|
||||||
|
response: &mut Builder,
|
||||||
|
request_headers: &HeaderMap,
|
||||||
|
allowed_origins: Option<&Vec<String>>,
|
||||||
|
) {
|
||||||
|
let request_origin = request_headers
|
||||||
|
.get(ORIGIN)
|
||||||
|
.map(|v| v.to_str().unwrap_or(""));
|
||||||
|
|
||||||
|
let response_allow_origin = match (request_origin, allowed_origins) {
|
||||||
|
(Some(or), Some(allowed_origins)) => {
|
||||||
|
if allowed_origins.iter().any(|o| o == or) {
|
||||||
|
Some(HeaderValue::from_str(or).unwrap_or(HEADER_VALUE_ALLOW_ALL_ORIGINS))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(Some(_), None) => Some(HEADER_VALUE_ALLOW_ALL_ORIGINS),
|
||||||
|
_ => None,
|
||||||
|
};
|
||||||
|
if let Some(h) = response.headers_mut() {
|
||||||
|
h.insert(
|
||||||
|
ACCESS_CONTROL_EXPOSE_HEADERS,
|
||||||
|
ACCESS_CONTROL_EXPOSE_HEADERS_VALUE,
|
||||||
|
);
|
||||||
|
if let Some(origin) = response_allow_origin {
|
||||||
|
h.insert(ACCESS_CONTROL_ALLOW_ORIGIN, origin);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
async fn handle_rest_inner(
|
async fn handle_rest_inner(
|
||||||
config: &'static ProxyConfig,
|
config: &'static ProxyConfig,
|
||||||
@@ -733,12 +778,6 @@ async fn handle_rest_inner(
|
|||||||
jwt: String,
|
jwt: String,
|
||||||
backend: Arc<PoolingBackend>,
|
backend: Arc<PoolingBackend>,
|
||||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, RestError> {
|
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, RestError> {
|
||||||
// validate the jwt token
|
|
||||||
let jwt_parsed = backend
|
|
||||||
.authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
|
|
||||||
.await
|
|
||||||
.map_err(HttpConnError::from)?;
|
|
||||||
|
|
||||||
let db_schema_cache =
|
let db_schema_cache =
|
||||||
config
|
config
|
||||||
.rest_config
|
.rest_config
|
||||||
@@ -754,28 +793,83 @@ async fn handle_rest_inner(
|
|||||||
message: "Failed to get endpoint cache key".to_string(),
|
message: "Failed to get endpoint cache key".to_string(),
|
||||||
}))?;
|
}))?;
|
||||||
|
|
||||||
let mut client = backend.connect_to_local_proxy(ctx, conn_info).await?;
|
|
||||||
|
|
||||||
let (parts, originial_body) = request.into_parts();
|
let (parts, originial_body) = request.into_parts();
|
||||||
|
|
||||||
|
// try and get the cached entry for this endpoint
|
||||||
|
// it contains the api config and the introspected db schema
|
||||||
|
let cached_entry = db_schema_cache.get_cached(&endpoint_cache_key);
|
||||||
|
|
||||||
|
let allowed_origins = cached_entry
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|arc| arc.0.server_cors_allowed_origins.as_ref());
|
||||||
|
|
||||||
|
let mut response = Response::builder();
|
||||||
|
apply_common_cors_headers(&mut response, &parts.headers, allowed_origins);
|
||||||
|
|
||||||
|
// handle the OPTIONS request
|
||||||
|
if parts.method == Method::OPTIONS {
|
||||||
|
let allowed_headers = parts
|
||||||
|
.headers
|
||||||
|
.get(ACCESS_CONTROL_REQUEST_HEADERS)
|
||||||
|
.and_then(|a| a.to_str().ok())
|
||||||
|
.filter(|v| !v.is_empty())
|
||||||
|
.map_or_else(
|
||||||
|
|| "Authorization".to_string(),
|
||||||
|
|v| format!("{v}, Authorization"),
|
||||||
|
);
|
||||||
|
return response
|
||||||
|
.status(StatusCode::OK)
|
||||||
|
.header(
|
||||||
|
ACCESS_CONTROL_ALLOW_METHODS,
|
||||||
|
ACCESS_CONTROL_ALLOW_METHODS_VALUE,
|
||||||
|
)
|
||||||
|
.header(ACCESS_CONTROL_MAX_AGE, ACCESS_CONTROL_MAX_AGE_VALUE)
|
||||||
|
.header(
|
||||||
|
ACCESS_CONTROL_ALLOW_HEADERS,
|
||||||
|
HeaderValue::from_str(&allowed_headers)
|
||||||
|
.unwrap_or(ACCESS_CONTROL_ALLOW_HEADERS_VALUE),
|
||||||
|
)
|
||||||
|
.header(ALLOW, ACCESS_CONTROL_ALLOW_METHODS_VALUE)
|
||||||
|
.body(Empty::new().map_err(|x| match x {}).boxed())
|
||||||
|
.map_err(|e| {
|
||||||
|
RestError::SubzeroCore(InternalError {
|
||||||
|
message: e.to_string(),
|
||||||
|
})
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// validate the jwt token
|
||||||
|
let jwt_parsed = backend
|
||||||
|
.authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
|
||||||
|
.await
|
||||||
|
.map_err(HttpConnError::from)?;
|
||||||
|
|
||||||
let auth_header = parts
|
let auth_header = parts
|
||||||
.headers
|
.headers
|
||||||
.get(AUTHORIZATION)
|
.get(AUTHORIZATION)
|
||||||
.ok_or(RestError::SubzeroCore(InternalError {
|
.ok_or(RestError::SubzeroCore(InternalError {
|
||||||
message: "Authorization header is required".to_string(),
|
message: "Authorization header is required".to_string(),
|
||||||
}))?;
|
}))?;
|
||||||
|
let mut client = backend.connect_to_local_proxy(ctx, conn_info).await?;
|
||||||
|
|
||||||
let entry = db_schema_cache
|
let entry = match cached_entry {
|
||||||
.get_cached_or_remote(
|
Some(e) => e,
|
||||||
&endpoint_cache_key,
|
None => {
|
||||||
auth_header,
|
// if not cached, get the remote entry (will run the introspection query)
|
||||||
connection_string,
|
db_schema_cache
|
||||||
&mut client,
|
.get_remote(
|
||||||
ctx,
|
&endpoint_cache_key,
|
||||||
config,
|
auth_header,
|
||||||
)
|
connection_string,
|
||||||
.await?;
|
&mut client,
|
||||||
|
ctx,
|
||||||
|
config,
|
||||||
|
)
|
||||||
|
.await?
|
||||||
|
}
|
||||||
|
};
|
||||||
let (api_config, db_schema_owned) = entry.as_ref();
|
let (api_config, db_schema_owned) = entry.as_ref();
|
||||||
|
|
||||||
let db_schema = db_schema_owned.borrow_schema();
|
let db_schema = db_schema_owned.borrow_schema();
|
||||||
|
|
||||||
let db_schemas = &api_config.db_schemas; // list of schemas available for the api
|
let db_schemas = &api_config.db_schemas; // list of schemas available for the api
|
||||||
@@ -999,8 +1093,8 @@ async fn handle_rest_inner(
|
|||||||
let _metrics = client.metrics(ctx); // FIXME: is everything in the context set correctly?
|
let _metrics = client.metrics(ctx); // FIXME: is everything in the context set correctly?
|
||||||
|
|
||||||
// send the request to the local proxy
|
// send the request to the local proxy
|
||||||
let response = make_raw_local_proxy_request(&mut client, headers, req_body).await?;
|
let proxy_response = make_raw_local_proxy_request(&mut client, headers, req_body).await?;
|
||||||
let (parts, body) = response.into_parts();
|
let (response_parts, body) = proxy_response.into_parts();
|
||||||
|
|
||||||
let max_response = config.http_config.max_response_size_bytes;
|
let max_response = config.http_config.max_response_size_bytes;
|
||||||
let bytes = read_body_with_limit(body, max_response)
|
let bytes = read_body_with_limit(body, max_response)
|
||||||
@@ -1009,7 +1103,7 @@ async fn handle_rest_inner(
|
|||||||
|
|
||||||
// if the response status is greater than 399, then it is an error
|
// if the response status is greater than 399, then it is an error
|
||||||
// FIXME: check if there are other error codes or shapes of the response
|
// FIXME: check if there are other error codes or shapes of the response
|
||||||
if parts.status.as_u16() > 399 {
|
if response_parts.status.as_u16() > 399 {
|
||||||
// turn this postgres error from the json into PostgresError
|
// turn this postgres error from the json into PostgresError
|
||||||
let postgres_error = serde_json::from_slice(&bytes)
|
let postgres_error = serde_json::from_slice(&bytes)
|
||||||
.map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?;
|
.map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?;
|
||||||
@@ -1175,7 +1269,7 @@ async fn handle_rest_inner(
|
|||||||
.boxed();
|
.boxed();
|
||||||
|
|
||||||
// build the response
|
// build the response
|
||||||
let mut response = Response::builder()
|
response = response
|
||||||
.status(StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))
|
.status(StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))
|
||||||
.header(CONTENT_TYPE, http_content_type);
|
.header(CONTENT_TYPE, http_content_type);
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user