[proxy] handle options request in rest broker (cors headers) (#12744)

## Problem
rest broker needs to respond with the correct cors headers for the api
to be usable from other domains

## Summary of changes
added a code path in rest broker to handle the OPTIONS requests

---------

Co-authored-by: Ruslan Talpa <ruslan.talpa@databricks.com>
This commit is contained in:
Ruslan Talpa
2025-07-31 16:05:09 +03:00
committed by GitHub
parent 312a74f11f
commit d96cea1917

View File

@@ -5,12 +5,17 @@ use std::sync::Arc;
use bytes::Bytes;
use http::Method;
use http::header::{AUTHORIZATION, CONTENT_TYPE, HOST};
use http::header::{
ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_METHODS, ACCESS_CONTROL_ALLOW_ORIGIN,
ACCESS_CONTROL_EXPOSE_HEADERS, ACCESS_CONTROL_MAX_AGE, ACCESS_CONTROL_REQUEST_HEADERS, ALLOW,
AUTHORIZATION, CONTENT_TYPE, HOST, ORIGIN,
};
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Full};
use http_body_util::{BodyExt, Empty, Full};
use http_utils::error::ApiError;
use hyper::body::Incoming;
use hyper::http::{HeaderName, HeaderValue};
use hyper::http::response::Builder;
use hyper::http::{HeaderMap, HeaderName, HeaderValue};
use hyper::{Request, Response, StatusCode};
use indexmap::IndexMap;
use moka::sync::Cache;
@@ -67,6 +72,15 @@ use crate::util::deserialize_json_string;
static EMPTY_JSON_SCHEMA: &str = r#"{"schemas":[]}"#;
const INTROSPECTION_SQL: &str = POSTGRESQL_INTROSPECTION_SQL;
const HEADER_VALUE_ALLOW_ALL_ORIGINS: HeaderValue = HeaderValue::from_static("*");
// CORS headers values
const ACCESS_CONTROL_ALLOW_METHODS_VALUE: HeaderValue =
HeaderValue::from_static("GET, POST, PATCH, PUT, DELETE, OPTIONS");
const ACCESS_CONTROL_MAX_AGE_VALUE: HeaderValue = HeaderValue::from_static("86400");
const ACCESS_CONTROL_EXPOSE_HEADERS_VALUE: HeaderValue = HeaderValue::from_static(
"Content-Encoding, Content-Location, Content-Range, Content-Type, Date, Location, Server, Transfer-Encoding, Range-Unit",
);
const ACCESS_CONTROL_ALLOW_HEADERS_VALUE: HeaderValue = HeaderValue::from_static("Authorization");
// A wrapper around the DbSchema that allows for self-referencing
#[self_referencing]
@@ -137,6 +151,8 @@ pub struct ApiConfig {
pub role_claim_key: String,
#[serde(default, deserialize_with = "deserialize_comma_separated_option")]
pub db_extra_search_path: Option<Vec<String>>,
#[serde(default, deserialize_with = "deserialize_comma_separated_option")]
pub server_cors_allowed_origins: Option<Vec<String>>,
}
// The DbSchemaCache is a cache of the ApiConfig and DbSchemaOwned for each endpoint
@@ -165,7 +181,13 @@ impl DbSchemaCache {
}
}
pub async fn get_cached_or_remote(
pub fn get_cached(
&self,
endpoint_id: &EndpointCacheKey,
) -> Option<Arc<(ApiConfig, DbSchemaOwned)>> {
count_cache_outcome(CacheKind::Schema, self.0.get(endpoint_id))
}
pub async fn get_remote(
&self,
endpoint_id: &EndpointCacheKey,
auth_header: &HeaderValue,
@@ -174,47 +196,42 @@ impl DbSchemaCache {
ctx: &RequestContext,
config: &'static ProxyConfig,
) -> Result<Arc<(ApiConfig, DbSchemaOwned)>, RestError> {
let cache_result = count_cache_outcome(CacheKind::Schema, self.0.get(endpoint_id));
match cache_result {
Some(v) => Ok(v),
None => {
info!("db_schema cache miss for endpoint: {:?}", endpoint_id);
let remote_value = self
.get_remote(auth_header, connection_string, client, ctx, config)
.await;
let (api_config, schema_owned) = match remote_value {
Ok((api_config, schema_owned)) => (api_config, schema_owned),
Err(e @ RestError::SchemaTooLarge) => {
// for the case where the schema is too large, we cache an empty dummy value
// all the other requests will fail without triggering the introspection query
let schema_owned = serde_json::from_str::<DbSchemaOwned>(EMPTY_JSON_SCHEMA)
.map_err(|e| JsonDeserialize { source: e })?;
info!("db_schema cache miss for endpoint: {:?}", endpoint_id);
let remote_value = self
.internal_get_remote(auth_header, connection_string, client, ctx, config)
.await;
let (api_config, schema_owned) = match remote_value {
Ok((api_config, schema_owned)) => (api_config, schema_owned),
Err(e @ RestError::SchemaTooLarge) => {
// for the case where the schema is too large, we cache an empty dummy value
// all the other requests will fail without triggering the introspection query
let schema_owned = serde_json::from_str::<DbSchemaOwned>(EMPTY_JSON_SCHEMA)
.map_err(|e| JsonDeserialize { source: e })?;
let api_config = ApiConfig {
db_schemas: vec![],
db_anon_role: None,
db_max_rows: None,
db_allowed_select_functions: vec![],
role_claim_key: String::new(),
db_extra_search_path: None,
};
let value = Arc::new((api_config, schema_owned));
count_cache_insert(CacheKind::Schema);
self.0.insert(endpoint_id.clone(), value);
return Err(e);
}
Err(e) => {
return Err(e);
}
let api_config = ApiConfig {
db_schemas: vec![],
db_anon_role: None,
db_max_rows: None,
db_allowed_select_functions: vec![],
role_claim_key: String::new(),
db_extra_search_path: None,
server_cors_allowed_origins: None,
};
let value = Arc::new((api_config, schema_owned));
count_cache_insert(CacheKind::Schema);
self.0.insert(endpoint_id.clone(), value.clone());
Ok(value)
self.0.insert(endpoint_id.clone(), value);
return Err(e);
}
}
Err(e) => {
return Err(e);
}
};
let value = Arc::new((api_config, schema_owned));
count_cache_insert(CacheKind::Schema);
self.0.insert(endpoint_id.clone(), value.clone());
Ok(value)
}
pub async fn get_remote(
async fn internal_get_remote(
&self,
auth_header: &HeaderValue,
connection_string: &str,
@@ -531,7 +548,7 @@ pub(crate) async fn handle(
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
let result = handle_inner(cancel, config, &ctx, request, backend).await;
let mut response = match result {
let response = match result {
Ok(r) => {
ctx.set_success();
@@ -640,9 +657,6 @@ pub(crate) async fn handle(
}
};
response
.headers_mut()
.insert("Access-Control-Allow-Origin", HeaderValue::from_static("*"));
Ok(response)
}
@@ -722,6 +736,37 @@ async fn handle_inner(
}
}
fn apply_common_cors_headers(
response: &mut Builder,
request_headers: &HeaderMap,
allowed_origins: Option<&Vec<String>>,
) {
let request_origin = request_headers
.get(ORIGIN)
.map(|v| v.to_str().unwrap_or(""));
let response_allow_origin = match (request_origin, allowed_origins) {
(Some(or), Some(allowed_origins)) => {
if allowed_origins.iter().any(|o| o == or) {
Some(HeaderValue::from_str(or).unwrap_or(HEADER_VALUE_ALLOW_ALL_ORIGINS))
} else {
None
}
}
(Some(_), None) => Some(HEADER_VALUE_ALLOW_ALL_ORIGINS),
_ => None,
};
if let Some(h) = response.headers_mut() {
h.insert(
ACCESS_CONTROL_EXPOSE_HEADERS,
ACCESS_CONTROL_EXPOSE_HEADERS_VALUE,
);
if let Some(origin) = response_allow_origin {
h.insert(ACCESS_CONTROL_ALLOW_ORIGIN, origin);
}
}
}
#[allow(clippy::too_many_arguments)]
async fn handle_rest_inner(
config: &'static ProxyConfig,
@@ -733,12 +778,6 @@ async fn handle_rest_inner(
jwt: String,
backend: Arc<PoolingBackend>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, RestError> {
// validate the jwt token
let jwt_parsed = backend
.authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
.await
.map_err(HttpConnError::from)?;
let db_schema_cache =
config
.rest_config
@@ -754,28 +793,83 @@ async fn handle_rest_inner(
message: "Failed to get endpoint cache key".to_string(),
}))?;
let mut client = backend.connect_to_local_proxy(ctx, conn_info).await?;
let (parts, originial_body) = request.into_parts();
// try and get the cached entry for this endpoint
// it contains the api config and the introspected db schema
let cached_entry = db_schema_cache.get_cached(&endpoint_cache_key);
let allowed_origins = cached_entry
.as_ref()
.and_then(|arc| arc.0.server_cors_allowed_origins.as_ref());
let mut response = Response::builder();
apply_common_cors_headers(&mut response, &parts.headers, allowed_origins);
// handle the OPTIONS request
if parts.method == Method::OPTIONS {
let allowed_headers = parts
.headers
.get(ACCESS_CONTROL_REQUEST_HEADERS)
.and_then(|a| a.to_str().ok())
.filter(|v| !v.is_empty())
.map_or_else(
|| "Authorization".to_string(),
|v| format!("{v}, Authorization"),
);
return response
.status(StatusCode::OK)
.header(
ACCESS_CONTROL_ALLOW_METHODS,
ACCESS_CONTROL_ALLOW_METHODS_VALUE,
)
.header(ACCESS_CONTROL_MAX_AGE, ACCESS_CONTROL_MAX_AGE_VALUE)
.header(
ACCESS_CONTROL_ALLOW_HEADERS,
HeaderValue::from_str(&allowed_headers)
.unwrap_or(ACCESS_CONTROL_ALLOW_HEADERS_VALUE),
)
.header(ALLOW, ACCESS_CONTROL_ALLOW_METHODS_VALUE)
.body(Empty::new().map_err(|x| match x {}).boxed())
.map_err(|e| {
RestError::SubzeroCore(InternalError {
message: e.to_string(),
})
});
}
// validate the jwt token
let jwt_parsed = backend
.authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
.await
.map_err(HttpConnError::from)?;
let auth_header = parts
.headers
.get(AUTHORIZATION)
.ok_or(RestError::SubzeroCore(InternalError {
message: "Authorization header is required".to_string(),
}))?;
let mut client = backend.connect_to_local_proxy(ctx, conn_info).await?;
let entry = db_schema_cache
.get_cached_or_remote(
&endpoint_cache_key,
auth_header,
connection_string,
&mut client,
ctx,
config,
)
.await?;
let entry = match cached_entry {
Some(e) => e,
None => {
// if not cached, get the remote entry (will run the introspection query)
db_schema_cache
.get_remote(
&endpoint_cache_key,
auth_header,
connection_string,
&mut client,
ctx,
config,
)
.await?
}
};
let (api_config, db_schema_owned) = entry.as_ref();
let db_schema = db_schema_owned.borrow_schema();
let db_schemas = &api_config.db_schemas; // list of schemas available for the api
@@ -999,8 +1093,8 @@ async fn handle_rest_inner(
let _metrics = client.metrics(ctx); // FIXME: is everything in the context set correctly?
// send the request to the local proxy
let response = make_raw_local_proxy_request(&mut client, headers, req_body).await?;
let (parts, body) = response.into_parts();
let proxy_response = make_raw_local_proxy_request(&mut client, headers, req_body).await?;
let (response_parts, body) = proxy_response.into_parts();
let max_response = config.http_config.max_response_size_bytes;
let bytes = read_body_with_limit(body, max_response)
@@ -1009,7 +1103,7 @@ async fn handle_rest_inner(
// if the response status is greater than 399, then it is an error
// FIXME: check if there are other error codes or shapes of the response
if parts.status.as_u16() > 399 {
if response_parts.status.as_u16() > 399 {
// turn this postgres error from the json into PostgresError
let postgres_error = serde_json::from_slice(&bytes)
.map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?;
@@ -1175,7 +1269,7 @@ async fn handle_rest_inner(
.boxed();
// build the response
let mut response = Response::builder()
response = response
.status(StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))
.header(CONTENT_TYPE, http_content_type);