From afac885c1089df62701625fd49bc05c0dae299f3 Mon Sep 17 00:00:00 2001 From: ShenJunkun <66257605+ShenJunkun@users.noreply.github.com> Date: Tue, 7 Feb 2023 11:07:32 +0800 Subject: [PATCH] refactor: add schema column to the scripts table (#868) --- src/datanode/src/instance/script.rs | 15 ++++++++--- src/datanode/src/script.rs | 17 ++++++++++--- src/frontend/src/instance.rs | 13 +++++++--- src/script/src/manager.rs | 25 ++++++++++++------ src/script/src/table.rs | 28 +++++++++++++++------ src/servers/src/http/script.rs | 23 +++++++++++++++-- src/servers/src/query_handler.rs | 4 +-- src/servers/tests/http/http_handler_test.rs | 8 ++++-- src/servers/tests/mod.rs | 10 +++++--- src/servers/tests/py_script/mod.rs | 4 ++- tests-integration/tests/http.rs | 7 ++++-- 11 files changed, 115 insertions(+), 39 deletions(-) diff --git a/src/datanode/src/instance/script.rs b/src/datanode/src/instance/script.rs index 140f1bcca5..8dd878546b 100644 --- a/src/datanode/src/instance/script.rs +++ b/src/datanode/src/instance/script.rs @@ -22,13 +22,20 @@ use crate::metric; #[async_trait] impl ScriptHandler for Instance { - async fn insert_script(&self, name: &str, script: &str) -> servers::error::Result<()> { + async fn insert_script( + &self, + schema: &str, + name: &str, + script: &str, + ) -> servers::error::Result<()> { let _timer = timer!(metric::METRIC_HANDLE_SCRIPTS_ELAPSED); - self.script_executor.insert_script(name, script).await + self.script_executor + .insert_script(schema, name, script) + .await } - async fn execute_script(&self, name: &str) -> servers::error::Result { + async fn execute_script(&self, schema: &str, name: &str) -> servers::error::Result { let _timer = timer!(metric::METRIC_RUN_SCRIPT_ELAPSED); - self.script_executor.execute_script(name).await + self.script_executor.execute_script(schema, name).await } } diff --git a/src/datanode/src/script.rs b/src/datanode/src/script.rs index 1a284da16d..b7cc622c95 100644 --- a/src/datanode/src/script.rs +++ b/src/datanode/src/script.rs @@ -71,10 +71,15 @@ mod python { }) } - pub async fn insert_script(&self, name: &str, script: &str) -> servers::error::Result<()> { + pub async fn insert_script( + &self, + schema: &str, + name: &str, + script: &str, + ) -> servers::error::Result<()> { let _s = self .script_manager - .insert_and_compile(name, script) + .insert_and_compile(schema, name, script) .await .map_err(|e| { error!(e; "Instance failed to insert script"); @@ -85,9 +90,13 @@ mod python { Ok(()) } - pub async fn execute_script(&self, name: &str) -> servers::error::Result { + pub async fn execute_script( + &self, + schema: &str, + name: &str, + ) -> servers::error::Result { self.script_manager - .execute(name) + .execute(schema, name) .await .map_err(|e| { error!(e; "Instance failed to execute script"); diff --git a/src/frontend/src/instance.rs b/src/frontend/src/instance.rs index 906f8e3ab8..642eba1f3e 100644 --- a/src/frontend/src/instance.rs +++ b/src/frontend/src/instance.rs @@ -501,9 +501,14 @@ impl SqlQueryHandler for Instance { #[async_trait] impl ScriptHandler for Instance { - async fn insert_script(&self, name: &str, script: &str) -> server_error::Result<()> { + async fn insert_script( + &self, + schema: &str, + name: &str, + script: &str, + ) -> server_error::Result<()> { if let Some(handler) = &self.script_handler { - handler.insert_script(name, script).await + handler.insert_script(schema, name, script).await } else { server_error::NotSupportedSnafu { feat: "Script execution in Frontend", @@ -512,9 +517,9 @@ impl ScriptHandler for Instance { } } - async fn execute_script(&self, script: &str) -> server_error::Result { + async fn execute_script(&self, schema: &str, script: &str) -> server_error::Result { if let Some(handler) = &self.script_handler { - handler.execute_script(script).await + handler.execute_script(schema, script).await } else { server_error::NotSupportedSnafu { feat: "Script execution in Frontend", diff --git a/src/script/src/manager.rs b/src/script/src/manager.rs index 59c6ea62cb..11f3a42ccd 100644 --- a/src/script/src/manager.rs +++ b/src/script/src/manager.rs @@ -65,20 +65,25 @@ impl ScriptManager { Ok(script) } - pub async fn insert_and_compile(&self, name: &str, script: &str) -> Result> { + pub async fn insert_and_compile( + &self, + schema: &str, + name: &str, + script: &str, + ) -> Result> { let compiled_script = self.compile(name, script).await?; - self.table.insert(name, script).await?; + self.table.insert(schema, name, script).await?; Ok(compiled_script) } - pub async fn execute(&self, name: &str) -> Result { + pub async fn execute(&self, schema: &str, name: &str) -> Result { let script = { let s = self.compiled.read().unwrap().get(name).cloned(); if s.is_some() { s } else { - self.try_find_script_and_compile(name).await? + self.try_find_script_and_compile(schema, name).await? } }; @@ -90,8 +95,12 @@ impl ScriptManager { .context(ExecutePythonSnafu { name }) } - async fn try_find_script_and_compile(&self, name: &str) -> Result>> { - let script = self.table.find_script_by_name(name).await?; + async fn try_find_script_and_compile( + &self, + schema: &str, + name: &str, + ) -> Result>> { + let script = self.table.find_script_by_name(schema, name).await?; Ok(Some(self.compile(name, &script).await?)) } @@ -149,9 +158,11 @@ mod tests { .unwrap(); catalog_manager.start().await.unwrap(); + let schema = "schema"; let name = "test"; mgr.table .insert( + schema, name, r#" @copr(sql='select number from numbers limit 10', args=['number'], returns=['n']) @@ -168,7 +179,7 @@ def test(n): } // try to find and compile - let script = mgr.try_find_script_and_compile(name).await.unwrap(); + let script = mgr.try_find_script_and_compile(schema, name).await.unwrap(); assert!(script.is_some()); { diff --git a/src/script/src/table.rs b/src/script/src/table.rs index e885a9e598..7fe267eeab 100644 --- a/src/script/src/table.rs +++ b/src/script/src/table.rs @@ -61,8 +61,8 @@ impl ScriptsTable { desc: Some("Scripts table".to_string()), schema, region_numbers: vec![0], - // name as primary key - primary_key_indices: vec![0], + //schema and name as primary key + primary_key_indices: vec![0, 1], create_if_not_exists: true, table_options: HashMap::default(), }; @@ -86,8 +86,12 @@ impl ScriptsTable { }) } - pub async fn insert(&self, name: &str, script: &str) -> Result<()> { - let mut columns_values: HashMap = HashMap::with_capacity(7); + pub async fn insert(&self, schema: &str, name: &str, script: &str) -> Result<()> { + let mut columns_values: HashMap = HashMap::with_capacity(8); + columns_values.insert( + "schema".to_string(), + Arc::new(StringVector::from(vec![schema])) as _, + ); columns_values.insert( "name".to_string(), Arc::new(StringVector::from(vec![name])) as _, @@ -115,7 +119,6 @@ impl ScriptsTable { "gmt_modified".to_string(), Arc::new(TimestampMillisecondVector::from_slice(&[now])) as _, ); - let table = self .catalog_manager .table( @@ -142,12 +145,18 @@ impl ScriptsTable { Ok(()) } - pub async fn find_script_by_name(&self, name: &str) -> Result { + pub async fn find_script_by_name(&self, schema: &str, name: &str) -> Result { // FIXME(dennis): SQL injection // TODO(dennis): we use sql to find the script, the better way is use a function // such as `find_record_by_primary_key` in table_engine. - let sql = format!("select script from {} where name='{}'", self.name(), name); + let sql = format!( + "select script from {} where schema='{}' and name='{}'", + self.name(), + schema, + name + ); let stmt = QueryLanguageParser::parse_sql(&sql).unwrap(); + let plan = self .query_engine .statement_to_plan(stmt, Arc::new(QueryContext::new())) @@ -195,6 +204,11 @@ impl ScriptsTable { /// Build scripts table fn build_scripts_schema() -> Schema { let cols = vec![ + ColumnSchema::new( + "schema".to_string(), + ConcreteDataType::string_datatype(), + false, + ), ColumnSchema::new( "name".to_string(), ConcreteDataType::string_datatype(), diff --git a/src/servers/src/http/script.rs b/src/servers/src/http/script.rs index 25738e065c..84bddde435 100644 --- a/src/servers/src/http/script.rs +++ b/src/servers/src/http/script.rs @@ -51,16 +51,26 @@ pub async fn scripts( RawBody(body): RawBody, ) -> Json { if let Some(script_handler) = &state.script_handler { + let schema = params.schema.as_ref(); + + if schema.is_none() || schema.unwrap().is_empty() { + json_err!("invalid schema") + } + let name = params.name.as_ref(); if name.is_none() || name.unwrap().is_empty() { json_err!("invalid name"); } + let bytes = unwrap_or_json_err!(hyper::body::to_bytes(body).await); let script = unwrap_or_json_err!(String::from_utf8(bytes.to_vec())); - let body = match script_handler.insert_script(name.unwrap(), &script).await { + let body = match script_handler + .insert_script(schema.unwrap(), name.unwrap(), &script) + .await + { Ok(()) => JsonResponse::with_output(None), Err(e) => json_err!(format!("Insert script error: {e}"), e.status_code()), }; @@ -73,6 +83,7 @@ pub async fn scripts( #[derive(Debug, Serialize, Deserialize, JsonSchema)] pub struct ScriptQuery { + pub schema: Option, pub name: Option, } @@ -84,6 +95,12 @@ pub async fn run_script( ) -> Json { if let Some(script_handler) = &state.script_handler { let start = Instant::now(); + let schema = params.schema.as_ref(); + + if schema.is_none() || schema.unwrap().is_empty() { + json_err!("invalid schema") + } + let name = params.name.as_ref(); if name.is_none() || name.unwrap().is_empty() { @@ -92,7 +109,9 @@ pub async fn run_script( // TODO(sunng87): query_context and db name resolution - let output = script_handler.execute_script(name.unwrap()).await; + let output = script_handler + .execute_script(schema.unwrap(), name.unwrap()) + .await; let resp = JsonResponse::from_output(vec![output]).await; Json(resp.with_execution_time(start.elapsed().as_millis())) diff --git a/src/servers/src/query_handler.rs b/src/servers/src/query_handler.rs index 2a1e59818e..edaad6f0ad 100644 --- a/src/servers/src/query_handler.rs +++ b/src/servers/src/query_handler.rs @@ -44,8 +44,8 @@ pub type ScriptHandlerRef = Arc; #[async_trait] pub trait ScriptHandler { - async fn insert_script(&self, name: &str, script: &str) -> Result<()>; - async fn execute_script(&self, name: &str) -> Result; + async fn insert_script(&self, schema: &str, name: &str, script: &str) -> Result<()>; + async fn execute_script(&self, schema: &str, name: &str) -> Result; } #[async_trait] diff --git a/src/servers/tests/http/http_handler_test.rs b/src/servers/tests/http/http_handler_test.rs index 0ff2701533..5ce7d07ea8 100644 --- a/src/servers/tests/http/http_handler_test.rs +++ b/src/servers/tests/http/http_handler_test.rs @@ -104,7 +104,7 @@ def test(n): ) .await; assert!(!json.success(), "{json:?}"); - assert_eq!(json.error().unwrap(), "Invalid argument: invalid name"); + assert_eq!(json.error().unwrap(), "Invalid argument: invalid schema"); let body = RawBody(Body::from(script)); let exec = create_script_query(); @@ -124,12 +124,16 @@ def test(n): fn create_script_query() -> Query { Query(script_handler::ScriptQuery { + schema: Some("test".to_string()), name: Some("test".to_string()), }) } fn create_invalid_script_query() -> Query { - Query(script_handler::ScriptQuery { name: None }) + Query(script_handler::ScriptQuery { + schema: None, + name: None, + }) } fn create_query() -> Query { diff --git a/src/servers/tests/mod.rs b/src/servers/tests/mod.rs index 9e43fe4bfe..f62357b78c 100644 --- a/src/servers/tests/mod.rs +++ b/src/servers/tests/mod.rs @@ -105,7 +105,7 @@ impl SqlQueryHandler for DummyInstance { #[async_trait] impl ScriptHandler for DummyInstance { - async fn insert_script(&self, name: &str, script: &str) -> Result<()> { + async fn insert_script(&self, schema: &str, name: &str, script: &str) -> Result<()> { let script = self .py_engine .compile(script, CompileContext::default()) @@ -115,13 +115,15 @@ impl ScriptHandler for DummyInstance { self.scripts .write() .unwrap() - .insert(name.to_string(), Arc::new(script)); + .insert(format!("{schema}_{name}"), Arc::new(script)); Ok(()) } - async fn execute_script(&self, name: &str) -> Result { - let py_script = self.scripts.read().unwrap().get(name).unwrap().clone(); + async fn execute_script(&self, schema: &str, name: &str) -> Result { + let key = format!("{schema}_{name}"); + + let py_script = self.scripts.read().unwrap().get(&key).unwrap().clone(); Ok(py_script.execute(EvalContext::default()).await.unwrap()) } diff --git a/src/servers/tests/py_script/mod.rs b/src/servers/tests/py_script/mod.rs index 3d9d9226b4..14b3b17116 100644 --- a/src/servers/tests/py_script/mod.rs +++ b/src/servers/tests/py_script/mod.rs @@ -33,7 +33,9 @@ async fn test_insert_py_udf_and_query() -> Result<()> { def double_that(col)->vector[u32]: return col*2 "#; - instance.insert_script("double_that", src).await?; + instance + .insert_script("schema_test", "double_that", src) + .await?; let res = instance .do_query("select double_that(uint32s) from numbers", query_ctx) .await diff --git a/tests-integration/tests/http.rs b/tests-integration/tests/http.rs index 2daa73649d..66999808bb 100644 --- a/tests-integration/tests/http.rs +++ b/tests-integration/tests/http.rs @@ -317,7 +317,7 @@ pub async fn test_scripts_api(store_type: StorageType) { let client = TestClient::new(app); let res = client - .post("/v1/scripts?name=test") + .post("/v1/scripts?schema=schema_test&name=test") .body( r#" @copr(sql='select number from numbers limit 10', args=['number'], returns=['n']) @@ -334,7 +334,10 @@ def test(n): assert!(body.output().is_none()); // call script - let res = client.post("/v1/run-script?name=test").send().await; + let res = client + .post("/v1/run-script?schema=schema_test&name=test") + .send() + .await; assert_eq!(res.status(), StatusCode::OK); let body = serde_json::from_str::(&res.text().await).unwrap();