From dce4ed9f1d5643cc180adfdd85f408c0d93515ba Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Fri, 24 Jan 2025 17:28:04 +0800 Subject: [PATCH] feat: add CORS headers for http interfaces (#5447) * feat: add cors headers for http server * test: add cors test * test: add preflight test --- Cargo.lock | 28 ++++++++++++++--- src/servers/Cargo.toml | 2 +- src/servers/src/http.rs | 14 +++++++++ src/servers/src/http/test_helpers.rs | 13 +++++++- tests-integration/Cargo.toml | 1 + tests-integration/tests/http.rs | 45 ++++++++++++++++++++++++++++ 6 files changed, 97 insertions(+), 6 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ea055ec03d..af8162528a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5894,7 +5894,7 @@ dependencies = [ "tokio", "tokio-util", "tower 0.4.13", - "tower-http", + "tower-http 0.5.2", "tracing", ] @@ -10439,7 +10439,7 @@ dependencies = [ "tonic 0.12.3", "tonic-reflection", "tower 0.5.2", - "tower-http", + "tower-http 0.6.2", "urlencoding", "uuid", "zstd 0.13.2", @@ -11760,6 +11760,7 @@ dependencies = [ "futures", "futures-util", "hex", + "http 1.1.0", "hyper-util", "itertools 0.10.5", "log-query", @@ -12363,10 +12364,29 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e9cd434a998747dd2c4276bc96ee2e0c7a2eadf3cae88e52be55a05fa9053f5" dependencies = [ - "async-compression 0.4.13", "base64 0.21.7", "bitflags 2.6.0", "bytes", + "http 1.1.0", + "http-body 1.0.1", + "http-body-util", + "mime", + "pin-project-lite", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tower-http" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "403fa3b783d4b626a8ad51d766ab03cb6d2dbfc46b1c5d4448395e6628dc9697" +dependencies = [ + "async-compression 0.4.13", + "base64 0.22.1", + "bitflags 2.6.0", + "bytes", "futures-core", "futures-util", "http 1.1.0", @@ -12381,7 +12401,7 @@ dependencies = [ "pin-project-lite", "tokio", "tokio-util", - "tower 0.4.13", + "tower 0.5.2", "tower-layer", "tower-service", "tracing", diff --git a/src/servers/Cargo.toml b/src/servers/Cargo.toml index 650eb94df3..74b3781f64 100644 --- a/src/servers/Cargo.toml +++ b/src/servers/Cargo.toml @@ -113,7 +113,7 @@ tokio-util.workspace = true tonic.workspace = true tonic-reflection = "0.12" tower = { workspace = true, features = ["full"] } -tower-http = { version = "0.5", features = ["full"] } +tower-http = { version = "0.6", features = ["full"] } urlencoding = "2.1" uuid.workspace = true zstd.workspace = true diff --git a/src/servers/src/http.rs b/src/servers/src/http.rs index 8ceccda226..e446e1a455 100644 --- a/src/servers/src/http.rs +++ b/src/servers/src/http.rs @@ -36,12 +36,14 @@ use datatypes::schema::SchemaRef; use datatypes::value::transform_value_ref_to_json_value; use event::{LogState, LogValidatorRef}; use futures::FutureExt; +use http::Method; use serde::{Deserialize, Serialize}; use serde_json::Value; use snafu::{ensure, ResultExt}; use tokio::sync::oneshot::{self, Sender}; use tokio::sync::Mutex; use tower::ServiceBuilder; +use tower_http::cors::{Any, CorsLayer}; use tower_http::decompression::RequestDecompressionLayer; use tower_http::trace::TraceLayer; @@ -737,6 +739,18 @@ impl HttpServer { // disable on failure tracing. because printing out isn't very helpful, // and we have impl IntoResponse for Error. It will print out more detailed error messages .layer(TraceLayer::new_for_http().on_failure(())) + .layer( + CorsLayer::new() + .allow_methods([ + Method::GET, + Method::POST, + Method::PUT, + Method::DELETE, + Method::HEAD, + ]) + .allow_origin(Any) + .allow_headers(Any), + ) .option_layer(timeout_layer) .option_layer(body_limit_layer) // auth layer diff --git a/src/servers/src/http/test_helpers.rs b/src/servers/src/http/test_helpers.rs index ca976220cf..09296dc10d 100644 --- a/src/servers/src/http/test_helpers.rs +++ b/src/servers/src/http/test_helpers.rs @@ -37,7 +37,7 @@ use axum::Router; use bytes::Bytes; use common_telemetry::info; use http::header::{HeaderName, HeaderValue}; -use http::StatusCode; +use http::{Method, StatusCode}; use tokio::net::TcpListener; /// Test client to Axum servers. @@ -128,6 +128,17 @@ impl TestClient { builder: self.client.delete(format!("http://{}{}", self.addr, url)), } } + + /// Options preflight request + pub fn options(&self, url: &str) -> RequestBuilder { + common_telemetry::info!("OPTIONS {} {}", self.addr, url); + + RequestBuilder { + builder: self + .client + .request(Method::OPTIONS, format!("http://{}{}", self.addr, url)), + } + } } /// Builder for test requests. diff --git a/tests-integration/Cargo.toml b/tests-integration/Cargo.toml index e6b6a8e166..3ededb0769 100644 --- a/tests-integration/Cargo.toml +++ b/tests-integration/Cargo.toml @@ -87,6 +87,7 @@ zstd.workspace = true datafusion.workspace = true datafusion-expr.workspace = true hex.workspace = true +http.workspace = true itertools.workspace = true opentelemetry-proto.workspace = true partition.workspace = true diff --git a/tests-integration/tests/http.rs b/tests-integration/tests/http.rs index 714f19a972..e4dc1efd50 100644 --- a/tests-integration/tests/http.rs +++ b/tests-integration/tests/http.rs @@ -151,6 +151,51 @@ pub async fn test_http_auth(store_type: StorageType) { guard.remove_all().await; } +#[tokio::test] +pub async fn test_cors() { + let (app, mut guard) = setup_test_http_app_with_frontend(StorageType::File, "test_cors").await; + let client = TestClient::new(app).await; + + let res = client.get("/health").send().await; + + assert_eq!(res.status(), StatusCode::OK); + assert_eq!( + res.headers() + .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN) + .expect("expect cors header origin"), + "*" + ); + + let res = client + .options("/health") + .header("Access-Control-Request-Headers", "x-greptime-auth") + .header("Access-Control-Request-Method", "DELETE") + .header("Origin", "https://example.com") + .send() + .await; + assert_eq!(res.status(), StatusCode::OK); + assert_eq!( + res.headers() + .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN) + .expect("expect cors header origin"), + "*" + ); + assert_eq!( + res.headers() + .get(http::header::ACCESS_CONTROL_ALLOW_HEADERS) + .expect("expect cors header headers"), + "*" + ); + assert_eq!( + res.headers() + .get(http::header::ACCESS_CONTROL_ALLOW_METHODS) + .expect("expect cors header methods"), + "GET,POST,PUT,DELETE,HEAD" + ); + + guard.remove_all().await; +} + pub async fn test_sql_api(store_type: StorageType) { let (app, mut guard) = setup_test_http_app_with_frontend(store_type, "sql_api").await; let client = TestClient::new(app).await;