mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-29 11:20:38 +00:00
feat: supports passing user params into coprocessor (#962)
* feat: make args in coprocessor optional * feat: supports kwargs for coprocessor as params passed by the users * feat: supports params for /run-script * fix: we should rewrite the coprocessor by removing kwargs * fix: remove println * fix: compile error after rebasing * fix: improve http_handler_test * test: http scripts api with user params * refactor: tweak all to_owned
This commit is contained in:
@@ -11,7 +11,7 @@
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::time::Instant;
|
||||
|
||||
use axum::extract::{Json, Query, RawBody, State};
|
||||
@@ -81,10 +81,12 @@ pub async fn scripts(
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema, Default)]
|
||||
pub struct ScriptQuery {
|
||||
pub db: Option<String>,
|
||||
pub name: Option<String>,
|
||||
#[serde(flatten)]
|
||||
pub params: HashMap<String, String>,
|
||||
}
|
||||
|
||||
/// Handler to execute script
|
||||
@@ -110,7 +112,7 @@ pub async fn run_script(
|
||||
// TODO(sunng87): query_context and db name resolution
|
||||
|
||||
let output = script_handler
|
||||
.execute_script(schema.unwrap(), name.unwrap())
|
||||
.execute_script(schema.unwrap(), name.unwrap(), params.params)
|
||||
.await;
|
||||
let resp = JsonResponse::from_output(vec![output]).await;
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@
|
||||
pub mod grpc;
|
||||
pub mod sql;
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use api::prometheus::remote::{ReadRequest, WriteRequest};
|
||||
@@ -45,7 +46,12 @@ pub type ScriptHandlerRef = Arc<dyn ScriptHandler + Send + Sync>;
|
||||
#[async_trait]
|
||||
pub trait ScriptHandler {
|
||||
async fn insert_script(&self, schema: &str, name: &str, script: &str) -> Result<()>;
|
||||
async fn execute_script(&self, schema: &str, name: &str) -> Result<Output>;
|
||||
async fn execute_script(
|
||||
&self,
|
||||
schema: &str,
|
||||
name: &str,
|
||||
params: HashMap<String, String>,
|
||||
) -> Result<Output>;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
|
||||
@@ -23,7 +23,10 @@ use servers::http::{handler as http_handler, script as script_handler, ApiState,
|
||||
use session::context::UserInfo;
|
||||
use table::test_util::MemTable;
|
||||
|
||||
use crate::{create_testing_script_handler, create_testing_sql_query_handler};
|
||||
use crate::{
|
||||
create_testing_script_handler, create_testing_sql_query_handler, ScriptHandlerRef,
|
||||
ServerSqlQueryHandlerRef,
|
||||
};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_sql_not_provided() {
|
||||
@@ -68,6 +71,25 @@ async fn test_sql_output_rows() {
|
||||
match &json.output().expect("assertion failed")[0] {
|
||||
JsonOutput::Records(records) => {
|
||||
assert_eq!(1, records.num_rows());
|
||||
let json = serde_json::to_string_pretty(&records).unwrap();
|
||||
assert_eq!(
|
||||
json,
|
||||
r#"{
|
||||
"schema": {
|
||||
"column_schemas": [
|
||||
{
|
||||
"name": "SUM(numbers.uint32s)",
|
||||
"data_type": "UInt64"
|
||||
}
|
||||
]
|
||||
},
|
||||
"rows": [
|
||||
[
|
||||
4950
|
||||
]
|
||||
]
|
||||
}"#
|
||||
);
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
@@ -95,6 +117,25 @@ async fn test_sql_form() {
|
||||
match &json.output().expect("assertion failed")[0] {
|
||||
JsonOutput::Records(records) => {
|
||||
assert_eq!(1, records.num_rows());
|
||||
let json = serde_json::to_string_pretty(&records).unwrap();
|
||||
assert_eq!(
|
||||
json,
|
||||
r#"{
|
||||
"schema": {
|
||||
"column_schemas": [
|
||||
{
|
||||
"name": "SUM(numbers.uint32s)",
|
||||
"data_type": "UInt64"
|
||||
}
|
||||
]
|
||||
},
|
||||
"rows": [
|
||||
[
|
||||
4950
|
||||
]
|
||||
]
|
||||
}"#
|
||||
);
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
@@ -110,18 +151,11 @@ async fn test_metrics() {
|
||||
assert!(text.contains("test_metrics counter"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_scripts() {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
|
||||
let script = r#"
|
||||
@copr(sql='select uint32s as number from numbers', args=['number'], returns=['n'])
|
||||
def test(n):
|
||||
return n;
|
||||
"#
|
||||
.to_string();
|
||||
let sql_handler = create_testing_sql_query_handler(MemTable::default_numbers_table());
|
||||
let script_handler = create_testing_script_handler(MemTable::default_numbers_table());
|
||||
async fn insert_script(
|
||||
script: String,
|
||||
script_handler: ScriptHandlerRef,
|
||||
sql_handler: ServerSqlQueryHandlerRef,
|
||||
) {
|
||||
let body = RawBody(Body::from(script.clone()));
|
||||
let invalid_query = create_invalid_script_query();
|
||||
let Json(json) = script_handler::scripts(
|
||||
@@ -136,12 +170,13 @@ def test(n):
|
||||
assert!(!json.success(), "{json:?}");
|
||||
assert_eq!(json.error().unwrap(), "Invalid argument: invalid schema");
|
||||
|
||||
let body = RawBody(Body::from(script));
|
||||
let body = RawBody(Body::from(script.clone()));
|
||||
let exec = create_script_query();
|
||||
// Insert the script
|
||||
let Json(json) = script_handler::scripts(
|
||||
State(ApiState {
|
||||
sql_handler,
|
||||
script_handler: Some(script_handler),
|
||||
sql_handler: sql_handler.clone(),
|
||||
script_handler: Some(script_handler.clone()),
|
||||
}),
|
||||
exec,
|
||||
body,
|
||||
@@ -152,10 +187,144 @@ def test(n):
|
||||
assert!(json.output().is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_scripts() {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
|
||||
let script = r#"
|
||||
@copr(sql='select uint32s as number from numbers limit 5', args=['number'], returns=['n'])
|
||||
def test(n) -> vector[i64]:
|
||||
return n;
|
||||
"#
|
||||
.to_string();
|
||||
let sql_handler = create_testing_sql_query_handler(MemTable::default_numbers_table());
|
||||
let script_handler = create_testing_script_handler(MemTable::default_numbers_table());
|
||||
|
||||
insert_script(script.clone(), script_handler.clone(), sql_handler.clone()).await;
|
||||
// Run the script
|
||||
let exec = create_script_query();
|
||||
let Json(json) = script_handler::run_script(
|
||||
State(ApiState {
|
||||
sql_handler,
|
||||
script_handler: Some(script_handler),
|
||||
}),
|
||||
exec,
|
||||
)
|
||||
.await;
|
||||
assert!(json.success(), "{json:?}");
|
||||
assert!(json.error().is_none());
|
||||
|
||||
match &json.output().unwrap()[0] {
|
||||
JsonOutput::Records(records) => {
|
||||
let json = serde_json::to_string_pretty(&records).unwrap();
|
||||
assert_eq!(5, records.num_rows());
|
||||
assert_eq!(
|
||||
json,
|
||||
r#"{
|
||||
"schema": {
|
||||
"column_schemas": [
|
||||
{
|
||||
"name": "n",
|
||||
"data_type": "Int64"
|
||||
}
|
||||
]
|
||||
},
|
||||
"rows": [
|
||||
[
|
||||
0
|
||||
],
|
||||
[
|
||||
1
|
||||
],
|
||||
[
|
||||
2
|
||||
],
|
||||
[
|
||||
3
|
||||
],
|
||||
[
|
||||
4
|
||||
]
|
||||
]
|
||||
}"#
|
||||
);
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_scripts_with_params() {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
|
||||
let script = r#"
|
||||
@copr(sql='select uint32s as number from numbers limit 5', args=['number'], returns=['n'])
|
||||
def test(n, **params) -> vector[i64]:
|
||||
return n + int(params['a'])
|
||||
"#
|
||||
.to_string();
|
||||
let sql_handler = create_testing_sql_query_handler(MemTable::default_numbers_table());
|
||||
let script_handler = create_testing_script_handler(MemTable::default_numbers_table());
|
||||
|
||||
insert_script(script.clone(), script_handler.clone(), sql_handler.clone()).await;
|
||||
// Run the script
|
||||
let mut exec = create_script_query();
|
||||
exec.0.params.insert("a".to_string(), "42".to_string());
|
||||
let Json(json) = script_handler::run_script(
|
||||
State(ApiState {
|
||||
sql_handler,
|
||||
script_handler: Some(script_handler),
|
||||
}),
|
||||
exec,
|
||||
)
|
||||
.await;
|
||||
assert!(json.success(), "{json:?}");
|
||||
assert!(json.error().is_none());
|
||||
|
||||
match &json.output().unwrap()[0] {
|
||||
JsonOutput::Records(records) => {
|
||||
let json = serde_json::to_string_pretty(&records).unwrap();
|
||||
assert_eq!(5, records.num_rows());
|
||||
assert_eq!(
|
||||
json,
|
||||
r#"{
|
||||
"schema": {
|
||||
"column_schemas": [
|
||||
{
|
||||
"name": "n",
|
||||
"data_type": "Int64"
|
||||
}
|
||||
]
|
||||
},
|
||||
"rows": [
|
||||
[
|
||||
42
|
||||
],
|
||||
[
|
||||
43
|
||||
],
|
||||
[
|
||||
44
|
||||
],
|
||||
[
|
||||
45
|
||||
],
|
||||
[
|
||||
46
|
||||
]
|
||||
]
|
||||
}"#
|
||||
);
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn create_script_query() -> Query<script_handler::ScriptQuery> {
|
||||
Query(script_handler::ScriptQuery {
|
||||
db: Some("test".to_string()),
|
||||
name: Some("test".to_string()),
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
|
||||
@@ -163,6 +332,7 @@ fn create_invalid_script_query() -> Query<script_handler::ScriptQuery> {
|
||||
Query(script_handler::ScriptQuery {
|
||||
db: None,
|
||||
name: None,
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -120,12 +120,20 @@ impl ScriptHandler for DummyInstance {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn execute_script(&self, schema: &str, name: &str) -> Result<Output> {
|
||||
async fn execute_script(
|
||||
&self,
|
||||
schema: &str,
|
||||
name: &str,
|
||||
params: HashMap<String, String>,
|
||||
) -> Result<Output> {
|
||||
let key = format!("{schema}_{name}");
|
||||
|
||||
let py_script = self.scripts.read().unwrap().get(&key).unwrap().clone();
|
||||
|
||||
Ok(py_script.execute(EvalContext::default()).await.unwrap())
|
||||
Ok(py_script
|
||||
.execute(params, EvalContext::default())
|
||||
.await
|
||||
.unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user