diff --git a/proxy/src/serverless/rest.rs b/proxy/src/serverless/rest.rs index 0c3d2c958d..9f98e87272 100644 --- a/proxy/src/serverless/rest.rs +++ b/proxy/src/serverless/rest.rs @@ -5,12 +5,17 @@ use std::sync::Arc; use bytes::Bytes; use http::Method; -use http::header::{AUTHORIZATION, CONTENT_TYPE, HOST}; +use http::header::{ + ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_METHODS, ACCESS_CONTROL_ALLOW_ORIGIN, + ACCESS_CONTROL_EXPOSE_HEADERS, ACCESS_CONTROL_MAX_AGE, ACCESS_CONTROL_REQUEST_HEADERS, ALLOW, + AUTHORIZATION, CONTENT_TYPE, HOST, ORIGIN, +}; use http_body_util::combinators::BoxBody; -use http_body_util::{BodyExt, Full}; +use http_body_util::{BodyExt, Empty, Full}; use http_utils::error::ApiError; use hyper::body::Incoming; -use hyper::http::{HeaderName, HeaderValue}; +use hyper::http::response::Builder; +use hyper::http::{HeaderMap, HeaderName, HeaderValue}; use hyper::{Request, Response, StatusCode}; use indexmap::IndexMap; use moka::sync::Cache; @@ -67,6 +72,15 @@ use crate::util::deserialize_json_string; static EMPTY_JSON_SCHEMA: &str = r#"{"schemas":[]}"#; const INTROSPECTION_SQL: &str = POSTGRESQL_INTROSPECTION_SQL; +const HEADER_VALUE_ALLOW_ALL_ORIGINS: HeaderValue = HeaderValue::from_static("*"); +// CORS headers values +const ACCESS_CONTROL_ALLOW_METHODS_VALUE: HeaderValue = + HeaderValue::from_static("GET, POST, PATCH, PUT, DELETE, OPTIONS"); +const ACCESS_CONTROL_MAX_AGE_VALUE: HeaderValue = HeaderValue::from_static("86400"); +const ACCESS_CONTROL_EXPOSE_HEADERS_VALUE: HeaderValue = HeaderValue::from_static( + "Content-Encoding, Content-Location, Content-Range, Content-Type, Date, Location, Server, Transfer-Encoding, Range-Unit", +); +const ACCESS_CONTROL_ALLOW_HEADERS_VALUE: HeaderValue = HeaderValue::from_static("Authorization"); // A wrapper around the DbSchema that allows for self-referencing #[self_referencing] @@ -137,6 +151,8 @@ pub struct ApiConfig { pub role_claim_key: String, #[serde(default, deserialize_with = "deserialize_comma_separated_option")] pub db_extra_search_path: Option>, + #[serde(default, deserialize_with = "deserialize_comma_separated_option")] + pub server_cors_allowed_origins: Option>, } // The DbSchemaCache is a cache of the ApiConfig and DbSchemaOwned for each endpoint @@ -165,7 +181,13 @@ impl DbSchemaCache { } } - pub async fn get_cached_or_remote( + pub fn get_cached( + &self, + endpoint_id: &EndpointCacheKey, + ) -> Option> { + count_cache_outcome(CacheKind::Schema, self.0.get(endpoint_id)) + } + pub async fn get_remote( &self, endpoint_id: &EndpointCacheKey, auth_header: &HeaderValue, @@ -174,47 +196,42 @@ impl DbSchemaCache { ctx: &RequestContext, config: &'static ProxyConfig, ) -> Result, RestError> { - let cache_result = count_cache_outcome(CacheKind::Schema, self.0.get(endpoint_id)); - match cache_result { - Some(v) => Ok(v), - None => { - info!("db_schema cache miss for endpoint: {:?}", endpoint_id); - let remote_value = self - .get_remote(auth_header, connection_string, client, ctx, config) - .await; - let (api_config, schema_owned) = match remote_value { - Ok((api_config, schema_owned)) => (api_config, schema_owned), - Err(e @ RestError::SchemaTooLarge) => { - // for the case where the schema is too large, we cache an empty dummy value - // all the other requests will fail without triggering the introspection query - let schema_owned = serde_json::from_str::(EMPTY_JSON_SCHEMA) - .map_err(|e| JsonDeserialize { source: e })?; + info!("db_schema cache miss for endpoint: {:?}", endpoint_id); + let remote_value = self + .internal_get_remote(auth_header, connection_string, client, ctx, config) + .await; + let (api_config, schema_owned) = match remote_value { + Ok((api_config, schema_owned)) => (api_config, schema_owned), + Err(e @ RestError::SchemaTooLarge) => { + // for the case where the schema is too large, we cache an empty dummy value + // all the other requests will fail without triggering the introspection query + let schema_owned = serde_json::from_str::(EMPTY_JSON_SCHEMA) + .map_err(|e| JsonDeserialize { source: e })?; - let api_config = ApiConfig { - db_schemas: vec![], - db_anon_role: None, - db_max_rows: None, - db_allowed_select_functions: vec![], - role_claim_key: String::new(), - db_extra_search_path: None, - }; - let value = Arc::new((api_config, schema_owned)); - count_cache_insert(CacheKind::Schema); - self.0.insert(endpoint_id.clone(), value); - return Err(e); - } - Err(e) => { - return Err(e); - } + let api_config = ApiConfig { + db_schemas: vec![], + db_anon_role: None, + db_max_rows: None, + db_allowed_select_functions: vec![], + role_claim_key: String::new(), + db_extra_search_path: None, + server_cors_allowed_origins: None, }; let value = Arc::new((api_config, schema_owned)); count_cache_insert(CacheKind::Schema); - self.0.insert(endpoint_id.clone(), value.clone()); - Ok(value) + self.0.insert(endpoint_id.clone(), value); + return Err(e); } - } + Err(e) => { + return Err(e); + } + }; + let value = Arc::new((api_config, schema_owned)); + count_cache_insert(CacheKind::Schema); + self.0.insert(endpoint_id.clone(), value.clone()); + Ok(value) } - pub async fn get_remote( + async fn internal_get_remote( &self, auth_header: &HeaderValue, connection_string: &str, @@ -531,7 +548,7 @@ pub(crate) async fn handle( ) -> Result>, ApiError> { let result = handle_inner(cancel, config, &ctx, request, backend).await; - let mut response = match result { + let response = match result { Ok(r) => { ctx.set_success(); @@ -640,9 +657,6 @@ pub(crate) async fn handle( } }; - response - .headers_mut() - .insert("Access-Control-Allow-Origin", HeaderValue::from_static("*")); Ok(response) } @@ -722,6 +736,37 @@ async fn handle_inner( } } +fn apply_common_cors_headers( + response: &mut Builder, + request_headers: &HeaderMap, + allowed_origins: Option<&Vec>, +) { + let request_origin = request_headers + .get(ORIGIN) + .map(|v| v.to_str().unwrap_or("")); + + let response_allow_origin = match (request_origin, allowed_origins) { + (Some(or), Some(allowed_origins)) => { + if allowed_origins.iter().any(|o| o == or) { + Some(HeaderValue::from_str(or).unwrap_or(HEADER_VALUE_ALLOW_ALL_ORIGINS)) + } else { + None + } + } + (Some(_), None) => Some(HEADER_VALUE_ALLOW_ALL_ORIGINS), + _ => None, + }; + if let Some(h) = response.headers_mut() { + h.insert( + ACCESS_CONTROL_EXPOSE_HEADERS, + ACCESS_CONTROL_EXPOSE_HEADERS_VALUE, + ); + if let Some(origin) = response_allow_origin { + h.insert(ACCESS_CONTROL_ALLOW_ORIGIN, origin); + } + } +} + #[allow(clippy::too_many_arguments)] async fn handle_rest_inner( config: &'static ProxyConfig, @@ -733,12 +778,6 @@ async fn handle_rest_inner( jwt: String, backend: Arc, ) -> Result>, RestError> { - // validate the jwt token - let jwt_parsed = backend - .authenticate_with_jwt(ctx, &conn_info.user_info, jwt) - .await - .map_err(HttpConnError::from)?; - let db_schema_cache = config .rest_config @@ -754,28 +793,83 @@ async fn handle_rest_inner( message: "Failed to get endpoint cache key".to_string(), }))?; - let mut client = backend.connect_to_local_proxy(ctx, conn_info).await?; - let (parts, originial_body) = request.into_parts(); + // try and get the cached entry for this endpoint + // it contains the api config and the introspected db schema + let cached_entry = db_schema_cache.get_cached(&endpoint_cache_key); + + let allowed_origins = cached_entry + .as_ref() + .and_then(|arc| arc.0.server_cors_allowed_origins.as_ref()); + + let mut response = Response::builder(); + apply_common_cors_headers(&mut response, &parts.headers, allowed_origins); + + // handle the OPTIONS request + if parts.method == Method::OPTIONS { + let allowed_headers = parts + .headers + .get(ACCESS_CONTROL_REQUEST_HEADERS) + .and_then(|a| a.to_str().ok()) + .filter(|v| !v.is_empty()) + .map_or_else( + || "Authorization".to_string(), + |v| format!("{v}, Authorization"), + ); + return response + .status(StatusCode::OK) + .header( + ACCESS_CONTROL_ALLOW_METHODS, + ACCESS_CONTROL_ALLOW_METHODS_VALUE, + ) + .header(ACCESS_CONTROL_MAX_AGE, ACCESS_CONTROL_MAX_AGE_VALUE) + .header( + ACCESS_CONTROL_ALLOW_HEADERS, + HeaderValue::from_str(&allowed_headers) + .unwrap_or(ACCESS_CONTROL_ALLOW_HEADERS_VALUE), + ) + .header(ALLOW, ACCESS_CONTROL_ALLOW_METHODS_VALUE) + .body(Empty::new().map_err(|x| match x {}).boxed()) + .map_err(|e| { + RestError::SubzeroCore(InternalError { + message: e.to_string(), + }) + }); + } + + // validate the jwt token + let jwt_parsed = backend + .authenticate_with_jwt(ctx, &conn_info.user_info, jwt) + .await + .map_err(HttpConnError::from)?; + let auth_header = parts .headers .get(AUTHORIZATION) .ok_or(RestError::SubzeroCore(InternalError { message: "Authorization header is required".to_string(), }))?; + let mut client = backend.connect_to_local_proxy(ctx, conn_info).await?; - let entry = db_schema_cache - .get_cached_or_remote( - &endpoint_cache_key, - auth_header, - connection_string, - &mut client, - ctx, - config, - ) - .await?; + let entry = match cached_entry { + Some(e) => e, + None => { + // if not cached, get the remote entry (will run the introspection query) + db_schema_cache + .get_remote( + &endpoint_cache_key, + auth_header, + connection_string, + &mut client, + ctx, + config, + ) + .await? + } + }; let (api_config, db_schema_owned) = entry.as_ref(); + let db_schema = db_schema_owned.borrow_schema(); let db_schemas = &api_config.db_schemas; // list of schemas available for the api @@ -999,8 +1093,8 @@ async fn handle_rest_inner( let _metrics = client.metrics(ctx); // FIXME: is everything in the context set correctly? // send the request to the local proxy - let response = make_raw_local_proxy_request(&mut client, headers, req_body).await?; - let (parts, body) = response.into_parts(); + let proxy_response = make_raw_local_proxy_request(&mut client, headers, req_body).await?; + let (response_parts, body) = proxy_response.into_parts(); let max_response = config.http_config.max_response_size_bytes; let bytes = read_body_with_limit(body, max_response) @@ -1009,7 +1103,7 @@ async fn handle_rest_inner( // if the response status is greater than 399, then it is an error // FIXME: check if there are other error codes or shapes of the response - if parts.status.as_u16() > 399 { + if response_parts.status.as_u16() > 399 { // turn this postgres error from the json into PostgresError let postgres_error = serde_json::from_slice(&bytes) .map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?; @@ -1175,7 +1269,7 @@ async fn handle_rest_inner( .boxed(); // build the response - let mut response = Response::builder() + response = response .status(StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)) .header(CONTENT_TYPE, http_content_type);