feat: supports passing user params into coprocessor (#962)

* feat: make args in coprocessor optional

* feat: supports kwargs for coprocessor as params passed by the users

* feat: supports params for /run-script

* fix: we should rewrite the coprocessor by removing kwargs

* fix: remove println

* fix: compile error after rebasing

* fix: improve http_handler_test

* test: http scripts api with user params

* refactor: tweak all to_owned
This commit is contained in:
dennis zhuang
2023-02-16 16:11:26 +08:00
committed by GitHub
parent ddbc97befb
commit 5ec1a7027b
19 changed files with 564 additions and 142 deletions

View File

@@ -11,7 +11,7 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use std::time::Instant;
use axum::extract::{Json, Query, RawBody, State};
@@ -81,10 +81,12 @@ pub async fn scripts(
}
}
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
#[derive(Debug, Serialize, Deserialize, JsonSchema, Default)]
pub struct ScriptQuery {
pub db: Option<String>,
pub name: Option<String>,
#[serde(flatten)]
pub params: HashMap<String, String>,
}
/// Handler to execute script
@@ -110,7 +112,7 @@ pub async fn run_script(
// TODO(sunng87): query_context and db name resolution
let output = script_handler
.execute_script(schema.unwrap(), name.unwrap())
.execute_script(schema.unwrap(), name.unwrap(), params.params)
.await;
let resp = JsonResponse::from_output(vec![output]).await;

View File

@@ -25,6 +25,7 @@
pub mod grpc;
pub mod sql;
use std::collections::HashMap;
use std::sync::Arc;
use api::prometheus::remote::{ReadRequest, WriteRequest};
@@ -45,7 +46,12 @@ pub type ScriptHandlerRef = Arc<dyn ScriptHandler + Send + Sync>;
#[async_trait]
pub trait ScriptHandler {
async fn insert_script(&self, schema: &str, name: &str, script: &str) -> Result<()>;
async fn execute_script(&self, schema: &str, name: &str) -> Result<Output>;
async fn execute_script(
&self,
schema: &str,
name: &str,
params: HashMap<String, String>,
) -> Result<Output>;
}
#[async_trait]

View File

@@ -23,7 +23,10 @@ use servers::http::{handler as http_handler, script as script_handler, ApiState,
use session::context::UserInfo;
use table::test_util::MemTable;
use crate::{create_testing_script_handler, create_testing_sql_query_handler};
use crate::{
create_testing_script_handler, create_testing_sql_query_handler, ScriptHandlerRef,
ServerSqlQueryHandlerRef,
};
#[tokio::test]
async fn test_sql_not_provided() {
@@ -68,6 +71,25 @@ async fn test_sql_output_rows() {
match &json.output().expect("assertion failed")[0] {
JsonOutput::Records(records) => {
assert_eq!(1, records.num_rows());
let json = serde_json::to_string_pretty(&records).unwrap();
assert_eq!(
json,
r#"{
"schema": {
"column_schemas": [
{
"name": "SUM(numbers.uint32s)",
"data_type": "UInt64"
}
]
},
"rows": [
[
4950
]
]
}"#
);
}
_ => unreachable!(),
}
@@ -95,6 +117,25 @@ async fn test_sql_form() {
match &json.output().expect("assertion failed")[0] {
JsonOutput::Records(records) => {
assert_eq!(1, records.num_rows());
let json = serde_json::to_string_pretty(&records).unwrap();
assert_eq!(
json,
r#"{
"schema": {
"column_schemas": [
{
"name": "SUM(numbers.uint32s)",
"data_type": "UInt64"
}
]
},
"rows": [
[
4950
]
]
}"#
);
}
_ => unreachable!(),
}
@@ -110,18 +151,11 @@ async fn test_metrics() {
assert!(text.contains("test_metrics counter"));
}
#[tokio::test]
async fn test_scripts() {
common_telemetry::init_default_ut_logging();
let script = r#"
@copr(sql='select uint32s as number from numbers', args=['number'], returns=['n'])
def test(n):
return n;
"#
.to_string();
let sql_handler = create_testing_sql_query_handler(MemTable::default_numbers_table());
let script_handler = create_testing_script_handler(MemTable::default_numbers_table());
async fn insert_script(
script: String,
script_handler: ScriptHandlerRef,
sql_handler: ServerSqlQueryHandlerRef,
) {
let body = RawBody(Body::from(script.clone()));
let invalid_query = create_invalid_script_query();
let Json(json) = script_handler::scripts(
@@ -136,12 +170,13 @@ def test(n):
assert!(!json.success(), "{json:?}");
assert_eq!(json.error().unwrap(), "Invalid argument: invalid schema");
let body = RawBody(Body::from(script));
let body = RawBody(Body::from(script.clone()));
let exec = create_script_query();
// Insert the script
let Json(json) = script_handler::scripts(
State(ApiState {
sql_handler,
script_handler: Some(script_handler),
sql_handler: sql_handler.clone(),
script_handler: Some(script_handler.clone()),
}),
exec,
body,
@@ -152,10 +187,144 @@ def test(n):
assert!(json.output().is_none());
}
#[tokio::test]
async fn test_scripts() {
common_telemetry::init_default_ut_logging();
let script = r#"
@copr(sql='select uint32s as number from numbers limit 5', args=['number'], returns=['n'])
def test(n) -> vector[i64]:
return n;
"#
.to_string();
let sql_handler = create_testing_sql_query_handler(MemTable::default_numbers_table());
let script_handler = create_testing_script_handler(MemTable::default_numbers_table());
insert_script(script.clone(), script_handler.clone(), sql_handler.clone()).await;
// Run the script
let exec = create_script_query();
let Json(json) = script_handler::run_script(
State(ApiState {
sql_handler,
script_handler: Some(script_handler),
}),
exec,
)
.await;
assert!(json.success(), "{json:?}");
assert!(json.error().is_none());
match &json.output().unwrap()[0] {
JsonOutput::Records(records) => {
let json = serde_json::to_string_pretty(&records).unwrap();
assert_eq!(5, records.num_rows());
assert_eq!(
json,
r#"{
"schema": {
"column_schemas": [
{
"name": "n",
"data_type": "Int64"
}
]
},
"rows": [
[
0
],
[
1
],
[
2
],
[
3
],
[
4
]
]
}"#
);
}
_ => unreachable!(),
}
}
#[tokio::test]
async fn test_scripts_with_params() {
common_telemetry::init_default_ut_logging();
let script = r#"
@copr(sql='select uint32s as number from numbers limit 5', args=['number'], returns=['n'])
def test(n, **params) -> vector[i64]:
return n + int(params['a'])
"#
.to_string();
let sql_handler = create_testing_sql_query_handler(MemTable::default_numbers_table());
let script_handler = create_testing_script_handler(MemTable::default_numbers_table());
insert_script(script.clone(), script_handler.clone(), sql_handler.clone()).await;
// Run the script
let mut exec = create_script_query();
exec.0.params.insert("a".to_string(), "42".to_string());
let Json(json) = script_handler::run_script(
State(ApiState {
sql_handler,
script_handler: Some(script_handler),
}),
exec,
)
.await;
assert!(json.success(), "{json:?}");
assert!(json.error().is_none());
match &json.output().unwrap()[0] {
JsonOutput::Records(records) => {
let json = serde_json::to_string_pretty(&records).unwrap();
assert_eq!(5, records.num_rows());
assert_eq!(
json,
r#"{
"schema": {
"column_schemas": [
{
"name": "n",
"data_type": "Int64"
}
]
},
"rows": [
[
42
],
[
43
],
[
44
],
[
45
],
[
46
]
]
}"#
);
}
_ => unreachable!(),
}
}
fn create_script_query() -> Query<script_handler::ScriptQuery> {
Query(script_handler::ScriptQuery {
db: Some("test".to_string()),
name: Some("test".to_string()),
..Default::default()
})
}
@@ -163,6 +332,7 @@ fn create_invalid_script_query() -> Query<script_handler::ScriptQuery> {
Query(script_handler::ScriptQuery {
db: None,
name: None,
..Default::default()
})
}

View File

@@ -120,12 +120,20 @@ impl ScriptHandler for DummyInstance {
Ok(())
}
async fn execute_script(&self, schema: &str, name: &str) -> Result<Output> {
async fn execute_script(
&self,
schema: &str,
name: &str,
params: HashMap<String, String>,
) -> Result<Output> {
let key = format!("{schema}_{name}");
let py_script = self.scripts.read().unwrap().get(&key).unwrap().clone();
Ok(py_script.execute(EvalContext::default()).await.unwrap())
Ok(py_script
.execute(params, EvalContext::default())
.await
.unwrap())
}
}