refactor: add schema column to the scripts table (#868)

This commit is contained in:
ShenJunkun
2023-02-07 11:07:32 +08:00
committed by GitHub
parent 5d62e193bd
commit afac885c10
11 changed files with 115 additions and 39 deletions

View File

@@ -22,13 +22,20 @@ use crate::metric;
#[async_trait]
impl ScriptHandler for Instance {
async fn insert_script(&self, name: &str, script: &str) -> servers::error::Result<()> {
async fn insert_script(
&self,
schema: &str,
name: &str,
script: &str,
) -> servers::error::Result<()> {
let _timer = timer!(metric::METRIC_HANDLE_SCRIPTS_ELAPSED);
self.script_executor.insert_script(name, script).await
self.script_executor
.insert_script(schema, name, script)
.await
}
async fn execute_script(&self, name: &str) -> servers::error::Result<Output> {
async fn execute_script(&self, schema: &str, name: &str) -> servers::error::Result<Output> {
let _timer = timer!(metric::METRIC_RUN_SCRIPT_ELAPSED);
self.script_executor.execute_script(name).await
self.script_executor.execute_script(schema, name).await
}
}

View File

@@ -71,10 +71,15 @@ mod python {
})
}
pub async fn insert_script(&self, name: &str, script: &str) -> servers::error::Result<()> {
pub async fn insert_script(
&self,
schema: &str,
name: &str,
script: &str,
) -> servers::error::Result<()> {
let _s = self
.script_manager
.insert_and_compile(name, script)
.insert_and_compile(schema, name, script)
.await
.map_err(|e| {
error!(e; "Instance failed to insert script");
@@ -85,9 +90,13 @@ mod python {
Ok(())
}
pub async fn execute_script(&self, name: &str) -> servers::error::Result<Output> {
pub async fn execute_script(
&self,
schema: &str,
name: &str,
) -> servers::error::Result<Output> {
self.script_manager
.execute(name)
.execute(schema, name)
.await
.map_err(|e| {
error!(e; "Instance failed to execute script");

View File

@@ -501,9 +501,14 @@ impl SqlQueryHandler for Instance {
#[async_trait]
impl ScriptHandler for Instance {
async fn insert_script(&self, name: &str, script: &str) -> server_error::Result<()> {
async fn insert_script(
&self,
schema: &str,
name: &str,
script: &str,
) -> server_error::Result<()> {
if let Some(handler) = &self.script_handler {
handler.insert_script(name, script).await
handler.insert_script(schema, name, script).await
} else {
server_error::NotSupportedSnafu {
feat: "Script execution in Frontend",
@@ -512,9 +517,9 @@ impl ScriptHandler for Instance {
}
}
async fn execute_script(&self, script: &str) -> server_error::Result<Output> {
async fn execute_script(&self, schema: &str, script: &str) -> server_error::Result<Output> {
if let Some(handler) = &self.script_handler {
handler.execute_script(script).await
handler.execute_script(schema, script).await
} else {
server_error::NotSupportedSnafu {
feat: "Script execution in Frontend",

View File

@@ -65,20 +65,25 @@ impl ScriptManager {
Ok(script)
}
pub async fn insert_and_compile(&self, name: &str, script: &str) -> Result<Arc<PyScript>> {
pub async fn insert_and_compile(
&self,
schema: &str,
name: &str,
script: &str,
) -> Result<Arc<PyScript>> {
let compiled_script = self.compile(name, script).await?;
self.table.insert(name, script).await?;
self.table.insert(schema, name, script).await?;
Ok(compiled_script)
}
pub async fn execute(&self, name: &str) -> Result<Output> {
pub async fn execute(&self, schema: &str, name: &str) -> Result<Output> {
let script = {
let s = self.compiled.read().unwrap().get(name).cloned();
if s.is_some() {
s
} else {
self.try_find_script_and_compile(name).await?
self.try_find_script_and_compile(schema, name).await?
}
};
@@ -90,8 +95,12 @@ impl ScriptManager {
.context(ExecutePythonSnafu { name })
}
async fn try_find_script_and_compile(&self, name: &str) -> Result<Option<Arc<PyScript>>> {
let script = self.table.find_script_by_name(name).await?;
async fn try_find_script_and_compile(
&self,
schema: &str,
name: &str,
) -> Result<Option<Arc<PyScript>>> {
let script = self.table.find_script_by_name(schema, name).await?;
Ok(Some(self.compile(name, &script).await?))
}
@@ -149,9 +158,11 @@ mod tests {
.unwrap();
catalog_manager.start().await.unwrap();
let schema = "schema";
let name = "test";
mgr.table
.insert(
schema,
name,
r#"
@copr(sql='select number from numbers limit 10', args=['number'], returns=['n'])
@@ -168,7 +179,7 @@ def test(n):
}
// try to find and compile
let script = mgr.try_find_script_and_compile(name).await.unwrap();
let script = mgr.try_find_script_and_compile(schema, name).await.unwrap();
assert!(script.is_some());
{

View File

@@ -61,8 +61,8 @@ impl ScriptsTable {
desc: Some("Scripts table".to_string()),
schema,
region_numbers: vec![0],
// name as primary key
primary_key_indices: vec![0],
//schema and name as primary key
primary_key_indices: vec![0, 1],
create_if_not_exists: true,
table_options: HashMap::default(),
};
@@ -86,8 +86,12 @@ impl ScriptsTable {
})
}
pub async fn insert(&self, name: &str, script: &str) -> Result<()> {
let mut columns_values: HashMap<String, VectorRef> = HashMap::with_capacity(7);
pub async fn insert(&self, schema: &str, name: &str, script: &str) -> Result<()> {
let mut columns_values: HashMap<String, VectorRef> = HashMap::with_capacity(8);
columns_values.insert(
"schema".to_string(),
Arc::new(StringVector::from(vec![schema])) as _,
);
columns_values.insert(
"name".to_string(),
Arc::new(StringVector::from(vec![name])) as _,
@@ -115,7 +119,6 @@ impl ScriptsTable {
"gmt_modified".to_string(),
Arc::new(TimestampMillisecondVector::from_slice(&[now])) as _,
);
let table = self
.catalog_manager
.table(
@@ -142,12 +145,18 @@ impl ScriptsTable {
Ok(())
}
pub async fn find_script_by_name(&self, name: &str) -> Result<String> {
pub async fn find_script_by_name(&self, schema: &str, name: &str) -> Result<String> {
// FIXME(dennis): SQL injection
// TODO(dennis): we use sql to find the script, the better way is use a function
// such as `find_record_by_primary_key` in table_engine.
let sql = format!("select script from {} where name='{}'", self.name(), name);
let sql = format!(
"select script from {} where schema='{}' and name='{}'",
self.name(),
schema,
name
);
let stmt = QueryLanguageParser::parse_sql(&sql).unwrap();
let plan = self
.query_engine
.statement_to_plan(stmt, Arc::new(QueryContext::new()))
@@ -195,6 +204,11 @@ impl ScriptsTable {
/// Build scripts table
fn build_scripts_schema() -> Schema {
let cols = vec![
ColumnSchema::new(
"schema".to_string(),
ConcreteDataType::string_datatype(),
false,
),
ColumnSchema::new(
"name".to_string(),
ConcreteDataType::string_datatype(),

View File

@@ -51,16 +51,26 @@ pub async fn scripts(
RawBody(body): RawBody,
) -> Json<JsonResponse> {
if let Some(script_handler) = &state.script_handler {
let schema = params.schema.as_ref();
if schema.is_none() || schema.unwrap().is_empty() {
json_err!("invalid schema")
}
let name = params.name.as_ref();
if name.is_none() || name.unwrap().is_empty() {
json_err!("invalid name");
}
let bytes = unwrap_or_json_err!(hyper::body::to_bytes(body).await);
let script = unwrap_or_json_err!(String::from_utf8(bytes.to_vec()));
let body = match script_handler.insert_script(name.unwrap(), &script).await {
let body = match script_handler
.insert_script(schema.unwrap(), name.unwrap(), &script)
.await
{
Ok(()) => JsonResponse::with_output(None),
Err(e) => json_err!(format!("Insert script error: {e}"), e.status_code()),
};
@@ -73,6 +83,7 @@ pub async fn scripts(
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct ScriptQuery {
pub schema: Option<String>,
pub name: Option<String>,
}
@@ -84,6 +95,12 @@ pub async fn run_script(
) -> Json<JsonResponse> {
if let Some(script_handler) = &state.script_handler {
let start = Instant::now();
let schema = params.schema.as_ref();
if schema.is_none() || schema.unwrap().is_empty() {
json_err!("invalid schema")
}
let name = params.name.as_ref();
if name.is_none() || name.unwrap().is_empty() {
@@ -92,7 +109,9 @@ pub async fn run_script(
// TODO(sunng87): query_context and db name resolution
let output = script_handler.execute_script(name.unwrap()).await;
let output = script_handler
.execute_script(schema.unwrap(), name.unwrap())
.await;
let resp = JsonResponse::from_output(vec![output]).await;
Json(resp.with_execution_time(start.elapsed().as_millis()))

View File

@@ -44,8 +44,8 @@ pub type ScriptHandlerRef = Arc<dyn ScriptHandler + Send + Sync>;
#[async_trait]
pub trait ScriptHandler {
async fn insert_script(&self, name: &str, script: &str) -> Result<()>;
async fn execute_script(&self, name: &str) -> Result<Output>;
async fn insert_script(&self, schema: &str, name: &str, script: &str) -> Result<()>;
async fn execute_script(&self, schema: &str, name: &str) -> Result<Output>;
}
#[async_trait]

View File

@@ -104,7 +104,7 @@ def test(n):
)
.await;
assert!(!json.success(), "{json:?}");
assert_eq!(json.error().unwrap(), "Invalid argument: invalid name");
assert_eq!(json.error().unwrap(), "Invalid argument: invalid schema");
let body = RawBody(Body::from(script));
let exec = create_script_query();
@@ -124,12 +124,16 @@ def test(n):
fn create_script_query() -> Query<script_handler::ScriptQuery> {
Query(script_handler::ScriptQuery {
schema: Some("test".to_string()),
name: Some("test".to_string()),
})
}
fn create_invalid_script_query() -> Query<script_handler::ScriptQuery> {
Query(script_handler::ScriptQuery { name: None })
Query(script_handler::ScriptQuery {
schema: None,
name: None,
})
}
fn create_query() -> Query<http_handler::SqlQuery> {

View File

@@ -105,7 +105,7 @@ impl SqlQueryHandler for DummyInstance {
#[async_trait]
impl ScriptHandler for DummyInstance {
async fn insert_script(&self, name: &str, script: &str) -> Result<()> {
async fn insert_script(&self, schema: &str, name: &str, script: &str) -> Result<()> {
let script = self
.py_engine
.compile(script, CompileContext::default())
@@ -115,13 +115,15 @@ impl ScriptHandler for DummyInstance {
self.scripts
.write()
.unwrap()
.insert(name.to_string(), Arc::new(script));
.insert(format!("{schema}_{name}"), Arc::new(script));
Ok(())
}
async fn execute_script(&self, name: &str) -> Result<Output> {
let py_script = self.scripts.read().unwrap().get(name).unwrap().clone();
async fn execute_script(&self, schema: &str, name: &str) -> Result<Output> {
let key = format!("{schema}_{name}");
let py_script = self.scripts.read().unwrap().get(&key).unwrap().clone();
Ok(py_script.execute(EvalContext::default()).await.unwrap())
}

View File

@@ -33,7 +33,9 @@ async fn test_insert_py_udf_and_query() -> Result<()> {
def double_that(col)->vector[u32]:
return col*2
"#;
instance.insert_script("double_that", src).await?;
instance
.insert_script("schema_test", "double_that", src)
.await?;
let res = instance
.do_query("select double_that(uint32s) from numbers", query_ctx)
.await

View File

@@ -317,7 +317,7 @@ pub async fn test_scripts_api(store_type: StorageType) {
let client = TestClient::new(app);
let res = client
.post("/v1/scripts?name=test")
.post("/v1/scripts?schema=schema_test&name=test")
.body(
r#"
@copr(sql='select number from numbers limit 10', args=['number'], returns=['n'])
@@ -334,7 +334,10 @@ def test(n):
assert!(body.output().is_none());
// call script
let res = client.post("/v1/run-script?name=test").send().await;
let res = client
.post("/v1/run-script?schema=schema_test&name=test")
.send()
.await;
assert_eq!(res.status(), StatusCode::OK);
let body = serde_json::from_str::<JsonResponse>(&res.text().await).unwrap();