From fc6d73b06bf4ed8f94fc37967b41e962148dad8b Mon Sep 17 00:00:00 2001 From: dennis zhuang Date: Fri, 4 Nov 2022 14:09:07 +0800 Subject: [PATCH] feat: improve /scripts API (#390) * feat: improve /scripts API * chore: json_err macro * chore: json_err macro and refactor code * fix: test --- src/datanode/src/tests/http_test.rs | 13 ++--- src/servers/src/http/handler.rs | 61 +++++++++++++-------- src/servers/tests/http/http_handler_test.rs | 39 ++++++++----- 3 files changed, 67 insertions(+), 46 deletions(-) diff --git a/src/datanode/src/tests/http_test.rs b/src/datanode/src/tests/http_test.rs index 25a83d3a28..75b5604cfb 100644 --- a/src/datanode/src/tests/http_test.rs +++ b/src/datanode/src/tests/http_test.rs @@ -5,7 +5,6 @@ use axum::http::StatusCode; use axum::Router; use axum_test_helper::TestClient; use datatypes::prelude::ConcreteDataType; -use servers::http::handler::ScriptExecution; use servers::http::HttpServer; use servers::server::Server; use test_util::TestGuard; @@ -111,16 +110,14 @@ async fn test_scripts_api() { let (app, _guard) = make_test_app("scripts_api").await; let client = TestClient::new(app); let res = client - .post("/v1/scripts") - .json(&ScriptExecution { - name: "test".to_string(), - script: r#" + .post("/v1/scripts?name=test") + .body( + r#" @copr(sql='select number from numbers limit 10', args=['number'], returns=['n']) def test(n): return n + 1; -"# - .to_string(), - }) +"#, + ) .send() .await; assert_eq!(res.status(), StatusCode::OK); diff --git a/src/servers/src/http/handler.rs b/src/servers/src/http/handler.rs index 7ddb3bc4c1..884b823647 100644 --- a/src/servers/src/http/handler.rs +++ b/src/servers/src/http/handler.rs @@ -1,7 +1,7 @@ use std::collections::HashMap; use aide::transform::TransformOperation; -use axum::extract::{Json, Query, State}; +use axum::extract::{Json, Query, RawBody, State}; use common_error::prelude::ErrorExt; use common_error::status_code::StatusCode; use common_telemetry::metric; @@ -47,54 +47,67 @@ pub async fn metrics(Query(_params): Query>) -> String { } } -#[derive(Debug, Deserialize, Serialize, JsonSchema)] -pub struct ScriptExecution { - pub name: String, - pub script: String, +macro_rules! json_err { + ($e: expr) => {{ + return Json(JsonResponse::with_error( + format!("Invalid argument: {}", $e), + common_error::status_code::StatusCode::InvalidArguments, + )); + }}; + + ($msg: expr, $code: expr) => {{ + return Json(JsonResponse::with_error($msg.to_string(), $code)); + }}; +} + +macro_rules! unwrap_or_json_err { + ($result: expr) => { + match $result { + Ok(result) => result, + Err(e) => json_err!(e), + } + }; } /// Handler to insert and compile script #[axum_macros::debug_handler] pub async fn scripts( State(query_handler): State, - Json(payload): Json, + Query(params): Query, + RawBody(body): RawBody, ) -> Json { - if payload.name.is_empty() || payload.script.is_empty() { - return Json(JsonResponse::with_error( - "Invalid name or script".to_string(), - StatusCode::InvalidArguments, - )); - } + let name = params.name.as_ref(); - let body = match query_handler - .insert_script(&payload.name, &payload.script) - .await - { + if name.is_none() || name.unwrap().is_empty() { + json_err!("invalid name"); + } + let bytes = unwrap_or_json_err!(hyper::body::to_bytes(body).await); + + let script = unwrap_or_json_err!(String::from_utf8(bytes.to_vec())); + + let body = match query_handler.insert_script(name.unwrap(), &script).await { Ok(()) => JsonResponse::with_output(None), - Err(e) => JsonResponse::with_error(format!("Insert script error: {}", e), e.status_code()), + Err(e) => json_err!(format!("Insert script error: {}", e), e.status_code()), }; Json(body) } #[derive(Debug, Serialize, Deserialize, JsonSchema)] -pub struct RunScriptQuery { - name: Option, +pub struct ScriptQuery { + pub name: Option, } /// Handler to execute script #[axum_macros::debug_handler] pub async fn run_script( State(query_handler): State, - Query(params): Query, + Query(params): Query, ) -> Json { let name = params.name.as_ref(); if name.is_none() || name.unwrap().is_empty() { - return Json(JsonResponse::with_error( - "Invalid name".to_string(), - StatusCode::InvalidArguments, - )); + json_err!("invalid name"); } let output = query_handler.execute_script(name.unwrap()).await; diff --git a/src/servers/tests/http/http_handler_test.rs b/src/servers/tests/http/http_handler_test.rs index c504b72d53..10c0fa9fbf 100644 --- a/src/servers/tests/http/http_handler_test.rs +++ b/src/servers/tests/http/http_handler_test.rs @@ -1,10 +1,10 @@ use std::collections::HashMap; -use axum::extract::{Json, Query, State}; +use axum::body::Body; +use axum::extract::{Json, Query, RawBody, State}; use common_telemetry::metric; use metrics::counter; use servers::http::handler as http_handler; -use servers::http::handler::ScriptExecution; use servers::http::JsonOutput; use table::test_util::MemTable; @@ -58,27 +58,38 @@ async fn test_metrics() { async fn test_scripts() { common_telemetry::init_default_ut_logging(); - let exec = create_script_payload(); - let query_handler = create_testing_sql_query_handler(MemTable::default_numbers_table()); + let script = r#" +@copr(sql='select uint32s as number from numbers', args=['number'], returns=['n']) +def test(n): + return n; +"# + .to_string(); - let Json(json) = http_handler::scripts(State(query_handler), exec).await; + let query_handler = create_testing_sql_query_handler(MemTable::default_numbers_table()); + let body = RawBody(Body::from(script.clone())); + let invalid_query = create_invalid_script_query(); + let Json(json) = http_handler::scripts(State(query_handler.clone()), invalid_query, body).await; + assert!(!json.success(), "{:?}", json); + assert_eq!(json.error().unwrap(), "Invalid argument: invalid name"); + + let body = RawBody(Body::from(script)); + let exec = create_script_query(); + let Json(json) = http_handler::scripts(State(query_handler), exec, body).await; assert!(json.success(), "{:?}", json); assert!(json.error().is_none()); assert!(json.output().is_none()); } -fn create_script_payload() -> Json { - Json(ScriptExecution { - name: "test".to_string(), - script: r#" -@copr(sql='select uint32s as number from numbers', args=['number'], returns=['n']) -def test(n): - return n; -"# - .to_string(), +fn create_script_query() -> Query { + Query(http_handler::ScriptQuery { + name: Some("test".to_string()), }) } +fn create_invalid_script_query() -> Query { + Query(http_handler::ScriptQuery { name: None }) +} + fn create_query() -> Query { Query(http_handler::SqlQuery { sql: Some("select sum(uint32s) from numbers limit 20".to_string()),