From 5566f34bd17a51fcd7052efd11b8e740c3857ba2 Mon Sep 17 00:00:00 2001 From: dennis zhuang Date: Mon, 18 Sep 2023 19:43:21 +0800 Subject: [PATCH] feat: make scripts table work again (#2420) * feat: make scripts table work again * chore: typo * fix: license header * Update src/table/src/metadata.rs Co-authored-by: Ruihang Xia * chore: cr comments --------- Co-authored-by: Ruihang Xia --- Cargo.lock | 5 + src/frontend/Cargo.toml | 1 + src/frontend/src/instance.rs | 2 +- src/frontend/src/instance/script.rs | 9 +- src/frontend/src/script.rs | 143 ++++++++++-- src/script/Cargo.toml | 4 + src/script/src/error.rs | 27 +-- src/script/src/lib.rs | 2 + src/script/src/manager.rs | 156 +++++++++---- src/script/src/python/engine.rs | 7 +- src/script/src/table.rs | 340 +++++++++++++++------------- src/script/src/test.rs | 78 +++++++ src/servers/src/http/script.rs | 19 +- src/servers/src/query_handler.rs | 9 +- src/servers/tests/mod.rs | 18 +- src/servers/tests/py_script/mod.rs | 83 +++++-- src/table/src/metadata.rs | 4 + 17 files changed, 620 insertions(+), 287 deletions(-) create mode 100644 src/script/src/test.rs diff --git a/Cargo.lock b/Cargo.lock index 6f8b418c62..689b92c80d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3240,6 +3240,7 @@ name = "frontend" version = "0.4.0-nightly" dependencies = [ "api", + "arc-swap", "arrow-flight", "async-compat", "async-stream", @@ -8427,6 +8428,8 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" name = "script" version = "0.4.0-nightly" dependencies = [ + "api", + "arc-swap", "arrow", "async-trait", "catalog", @@ -8452,6 +8455,7 @@ dependencies = [ "log-store", "mito", "once_cell", + "operator", "paste", "pyo3", "query", @@ -8466,6 +8470,7 @@ dependencies = [ "rustpython-stdlib", "rustpython-vm", "serde", + "servers", "session", "snafu", "sql", diff --git a/src/frontend/Cargo.toml b/src/frontend/Cargo.toml index 4d92780041..285ad121ba 100644 --- a/src/frontend/Cargo.toml +++ b/src/frontend/Cargo.toml @@ -11,6 +11,7 @@ testing = [] [dependencies] api = { workspace = true } +arc-swap = "1.0" arrow-flight.workspace = true async-compat = "0.2" async-stream.workspace = true diff --git a/src/frontend/src/instance.rs b/src/frontend/src/instance.rs index 4a2f13df4e..fabedf7e4e 100644 --- a/src/frontend/src/instance.rs +++ b/src/frontend/src/instance.rs @@ -393,7 +393,7 @@ impl FrontendInstance for Instance { heartbeat_task.start().await?; } - self.script_executor.start(self).await?; + self.script_executor.start(self)?; futures::future::try_join_all(self.servers.values().map(start_server)) .await diff --git a/src/frontend/src/instance/script.rs b/src/frontend/src/instance/script.rs index d3eb5cb29f..f0aee2074b 100644 --- a/src/frontend/src/instance/script.rs +++ b/src/frontend/src/instance/script.rs @@ -18,6 +18,7 @@ use async_trait::async_trait; use common_query::Output; use common_telemetry::timer; use servers::query_handler::ScriptHandler; +use session::context::QueryContextRef; use crate::instance::Instance; use crate::metrics; @@ -26,25 +27,25 @@ use crate::metrics; impl ScriptHandler for Instance { async fn insert_script( &self, - schema: &str, + query_ctx: QueryContextRef, name: &str, script: &str, ) -> servers::error::Result<()> { let _timer = timer!(metrics::METRIC_HANDLE_SCRIPTS_ELAPSED); self.script_executor - .insert_script(schema, name, script) + .insert_script(query_ctx, name, script) .await } async fn execute_script( &self, - schema: &str, + query_ctx: QueryContextRef, name: &str, params: HashMap, ) -> servers::error::Result { let _timer = timer!(metrics::METRIC_RUN_SCRIPT_ELAPSED); self.script_executor - .execute_script(schema, name, params) + .execute_script(query_ctx, name, params) .await } } diff --git a/src/frontend/src/script.rs b/src/frontend/src/script.rs index f31cee8469..03bd2db0a7 100644 --- a/src/frontend/src/script.rs +++ b/src/frontend/src/script.rs @@ -13,12 +13,17 @@ // limitations under the License. use std::collections::HashMap; +use std::sync::Arc; use catalog::CatalogManagerRef; use common_query::Output; use query::QueryEngineRef; +use servers::query_handler::grpc::GrpcQueryHandler; +use session::context::QueryContextRef; -use crate::error::Result; +use crate::error::{Error, Result}; + +type FrontendGrpcQueryHandlerRef = Arc + Send + Sync>; #[cfg(not(feature = "python"))] mod dummy { @@ -34,13 +39,13 @@ mod dummy { Ok(Self {}) } - pub async fn start(&self) -> Result<()> { + pub fn start(&self, instance: &Instance) -> Result<()> { Ok(()) } pub async fn insert_script( &self, - _schema: &str, + _query_ctx: QueryContextRef, _name: &str, _script: &str, ) -> servers::error::Result<()> { @@ -49,7 +54,7 @@ mod dummy { pub async fn execute_script( &self, - _schema: &str, + _query_ctx: QueryContextRef, _name: &str, _params: HashMap, ) -> servers::error::Result { @@ -63,10 +68,11 @@ mod python { use api::v1::ddl_request::Expr; use api::v1::greptime_request::Request; use api::v1::{CreateTableExpr, DdlRequest}; + use arc_swap::ArcSwap; use catalog::RegisterSystemTableRequest; use common_error::ext::BoxedError; use common_meta::table_name::TableName; - use common_telemetry::logging::error; + use common_telemetry::{error, info}; use operator::expr_factory; use script::manager::ScriptManager; use servers::query_handler::grpc::GrpcQueryHandler; @@ -78,8 +84,33 @@ mod python { use crate::error::{CatalogSnafu, InvalidSystemTableDefSnafu, TableNotFoundSnafu}; use crate::instance::Instance; + /// A placeholder for the real gRPC handler. + /// It is temporary and will be replaced soon. + struct DummyHandler; + + impl DummyHandler { + fn arc() -> Arc { + Arc::new(Self {}) + } + } + + #[async_trait::async_trait] + impl GrpcQueryHandler for DummyHandler { + type Error = Error; + + async fn do_query( + &self, + _query: Request, + _ctx: QueryContextRef, + ) -> std::result::Result { + unreachable!(); + } + } + pub struct ScriptExecutor { - script_manager: ScriptManager, + script_manager: ScriptManager, + grpc_handler: ArcSwap, + catalog_manager: CatalogManagerRef, } impl ScriptExecutor { @@ -87,21 +118,42 @@ mod python { catalog_manager: CatalogManagerRef, query_engine: QueryEngineRef, ) -> Result { + let grpc_handler = DummyHandler::arc(); Ok(Self { - script_manager: ScriptManager::new(catalog_manager, query_engine) + grpc_handler: ArcSwap::new(Arc::new(grpc_handler.clone() as _)), + script_manager: ScriptManager::new(grpc_handler as _, query_engine) .await .context(crate::error::StartScriptManagerSnafu)?, + catalog_manager, }) } - pub async fn start(&self, instance: &Instance) -> Result<()> { + pub fn start(&self, instance: &Instance) -> Result<()> { + let handler = Arc::new(instance.clone()); + self.grpc_handler.store(Arc::new(handler.clone() as _)); + self.script_manager + .start(handler) + .context(crate::error::StartScriptManagerSnafu)?; + + Ok(()) + } + + /// Create scripts table for the specific catalog if it's not exists. + /// The function is idempotent and safe to be called more than once for the same catalog + async fn create_scripts_table_if_need(&self, catalog: &str) -> Result<()> { + let scripts_table = self.script_manager.get_scripts_table(catalog); + + if scripts_table.is_some() { + return Ok(()); + } + let RegisterSystemTableRequest { create_table_request: request, open_hook, - } = self.script_manager.create_table_request(); + } = self.script_manager.create_table_request(catalog); - if let Some(table) = instance - .catalog_manager() + if let Some(table) = self + .catalog_manager .table( &request.catalog_name, &request.schema_name, @@ -111,9 +163,11 @@ mod python { .context(CatalogSnafu)? { if let Some(open_hook) = open_hook { - (open_hook)(table).await.context(CatalogSnafu)?; + (open_hook)(table.clone()).await.context(CatalogSnafu)?; } + self.script_manager.insert_scripts_table(catalog, table); + return Ok(()); } @@ -125,7 +179,9 @@ mod python { let expr = Self::create_table_expr(request)?; - let _ = instance + let _ = self + .grpc_handler + .load() .do_query( Request::Ddl(DdlRequest { expr: Some(Expr::CreateTable(expr)), @@ -134,8 +190,8 @@ mod python { ) .await?; - let table = instance - .catalog_manager() + let table = self + .catalog_manager .table( &table_name.catalog_name, &table_name.schema_name, @@ -148,9 +204,16 @@ mod python { })?; if let Some(open_hook) = open_hook { - (open_hook)(table).await.context(CatalogSnafu)?; + (open_hook)(table.clone()).await.context(CatalogSnafu)?; } + info!( + "Created scripts table {}.", + table.table_info().full_table_name() + ); + + self.script_manager.insert_scripts_table(catalog, table); + Ok(()) } @@ -196,16 +259,31 @@ mod python { pub async fn insert_script( &self, - schema: &str, + query_ctx: QueryContextRef, name: &str, script: &str, ) -> servers::error::Result<()> { - let _s = self - .script_manager - .insert_and_compile(schema, name, script) + self.create_scripts_table_if_need(query_ctx.current_catalog()) .await .map_err(|e| { - error!(e; "Instance failed to insert script"); + error!(e; "Failed to create scripts table"); + servers::error::InternalSnafu { + err_msg: e.to_string(), + } + .build() + })?; + + let _s = self + .script_manager + .insert_and_compile( + query_ctx.current_catalog(), + query_ctx.current_schema(), + name, + script, + ) + .await + .map_err(|e| { + error!(e; "Failed to insert script"); BoxedError::new(e) }) .context(servers::error::InsertScriptSnafu { name })?; @@ -215,15 +293,30 @@ mod python { pub async fn execute_script( &self, - schema: &str, + query_ctx: QueryContextRef, name: &str, params: HashMap, ) -> servers::error::Result { - self.script_manager - .execute(schema, name, params) + self.create_scripts_table_if_need(query_ctx.current_catalog()) .await .map_err(|e| { - error!(e; "Instance failed to execute script"); + error!(e; "Failed to create scripts table"); + servers::error::InternalSnafu { + err_msg: e.to_string(), + } + .build() + })?; + + self.script_manager + .execute( + query_ctx.current_catalog(), + query_ctx.current_schema(), + name, + params, + ) + .await + .map_err(|e| { + error!(e; "Failed to execute script"); BoxedError::new(e) }) .context(servers::error::ExecuteScriptSnafu { name }) diff --git a/src/script/Cargo.toml b/src/script/Cargo.toml index 20dc96b09f..3aab59a686 100644 --- a/src/script/Cargo.toml +++ b/src/script/Cargo.toml @@ -24,6 +24,8 @@ python = [ ] [dependencies] +api.workspace = true +arc-swap = "1.0" arrow.workspace = true async-trait.workspace = true catalog = { workspace = true } @@ -62,6 +64,7 @@ rustpython-vm = { git = "https://github.com/discord9/RustPython", optional = tru "default", "codegen", ] } +servers.workspace = true session = { workspace = true } snafu = { version = "0.7", features = ["backtraces"] } sql = { workspace = true } @@ -75,6 +78,7 @@ common-test-util = { workspace = true } criterion = { version = "0.4", features = ["html_reports", "async_tokio"] } log-store = { workspace = true } mito = { workspace = true } +operator.workspace = true rayon = "1.0" ron = "0.7" serde = { version = "1.0", features = ["derive"] } diff --git a/src/script/src/error.rs b/src/script/src/error.rs index feaf450f76..68dcc324f8 100644 --- a/src/script/src/error.rs +++ b/src/script/src/error.rs @@ -14,28 +14,16 @@ use std::any::Any; -use common_error::ext::ErrorExt; +use common_error::ext::{BoxedError, ErrorExt}; use common_error::status_code::StatusCode; use snafu::{Location, Snafu}; #[derive(Debug, Snafu)] #[snafu(visibility(pub))] pub enum Error { - #[snafu(display("Failed to find scripts table, source: {}", source))] - FindScriptsTable { - location: Location, - source: catalog::error::Error, - }, - #[snafu(display("Failed to find column in scripts table, name: {}", name))] FindColumnInScriptsTable { name: String, location: Location }, - #[snafu(display("Failed to register scripts table, source: {}", source))] - RegisterScriptsTable { - location: Location, - source: catalog::error::Error, - }, - #[snafu(display("Scripts table not found"))] ScriptsTableNotFound { location: Location }, @@ -47,7 +35,7 @@ pub enum Error { InsertScript { name: String, location: Location, - source: table::error::Error, + source: BoxedError, }, #[snafu(display("Failed to compile python script, name: {}, source: {}", name, source))] @@ -67,13 +55,6 @@ pub enum Error { #[snafu(display("Script not found, name: {}", name))] ScriptNotFound { location: Location, name: String }, - #[snafu(display("Failed to find script by name: {}", name))] - FindScript { - name: String, - location: Location, - source: query::error::Error, - }, - #[snafu(display("Failed to collect record batch, source: {}", source))] CollectRecords { location: Location, @@ -104,12 +85,8 @@ impl ErrorExt for Error { match self { FindColumnInScriptsTable { .. } | CastType { .. } => StatusCode::Unexpected, ScriptsTableNotFound { .. } => StatusCode::TableNotFound, - RegisterScriptsTable { source, .. } | FindScriptsTable { source, .. } => { - source.status_code() - } InsertScript { source, .. } => source.status_code(), CompilePython { source, .. } | ExecutePython { source, .. } => source.status_code(), - FindScript { source, .. } => source.status_code(), CollectRecords { source, .. } => source.status_code(), ScriptNotFound { .. } => StatusCode::InvalidArguments, BuildDfLogicalPlan { .. } => StatusCode::Internal, diff --git a/src/script/src/lib.rs b/src/script/src/lib.rs index 9f88abcc5d..eddbb4d2e0 100644 --- a/src/script/src/lib.rs +++ b/src/script/src/lib.rs @@ -22,3 +22,5 @@ pub mod manager; #[cfg(feature = "python")] pub mod python; pub mod table; +#[cfg(test)] +mod test; diff --git a/src/script/src/manager.rs b/src/script/src/manager.rs index f5a2a20723..a80140edf5 100644 --- a/src/script/src/manager.rs +++ b/src/script/src/manager.rs @@ -16,54 +16,70 @@ use std::collections::HashMap; use std::sync::{Arc, RwLock}; -use catalog::{CatalogManagerRef, OpenSystemTableHook, RegisterSystemTableRequest}; -use common_catalog::consts::{ - default_engine, DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, SCRIPTS_TABLE_ID, -}; +use arc_swap::ArcSwap; +use catalog::{OpenSystemTableHook, RegisterSystemTableRequest}; +use common_catalog::consts::{default_engine, DEFAULT_SCHEMA_NAME, SCRIPTS_TABLE_ID}; +use common_error::ext::ErrorExt; use common_query::Output; use common_telemetry::logging; use futures::future::FutureExt; use query::QueryEngineRef; +use servers::query_handler::grpc::GrpcQueryHandlerRef; use snafu::{OptionExt, ResultExt}; use table::requests::{CreateTableRequest, TableOptions}; use table::TableRef; use crate::engine::{CompileContext, EvalContext, Script, ScriptEngine}; -use crate::error::{CompilePythonSnafu, ExecutePythonSnafu, Result, ScriptNotFoundSnafu}; +use crate::error::{ + CompilePythonSnafu, ExecutePythonSnafu, Result, ScriptNotFoundSnafu, ScriptsTableNotFoundSnafu, +}; use crate::python::{PyEngine, PyScript}; -use crate::table::{build_scripts_schema, ScriptsTable, SCRIPTS_TABLE_NAME}; +use crate::table::{ + build_scripts_schema, get_primary_key_indices, ScriptsTable, ScriptsTableRef, + SCRIPTS_TABLE_NAME, +}; -pub struct ScriptManager { +pub struct ScriptManager { compiled: RwLock>>, py_engine: PyEngine, + grpc_handler: ArcSwap>, + // Catalog name -> `[ScriptsTable]` + tables: RwLock>>, query_engine: QueryEngineRef, - table: ScriptsTable, } -impl ScriptManager { +impl ScriptManager { pub async fn new( - catalog_manager: CatalogManagerRef, + grpc_handler: GrpcQueryHandlerRef, query_engine: QueryEngineRef, ) -> Result { Ok(Self { compiled: RwLock::new(HashMap::default()), py_engine: PyEngine::new(query_engine.clone()), - query_engine: query_engine.clone(), - table: ScriptsTable::new(catalog_manager, query_engine).await?, + query_engine, + grpc_handler: ArcSwap::new(Arc::new(grpc_handler)), + tables: RwLock::new(HashMap::default()), }) } - pub fn create_table_request(&self) -> RegisterSystemTableRequest { + pub fn start(&self, grpc_handler: GrpcQueryHandlerRef) -> Result<()> { + self.grpc_handler.store(Arc::new(grpc_handler)); + + Ok(()) + } + + pub fn create_table_request(&self, catalog: &str) -> RegisterSystemTableRequest { let request = CreateTableRequest { id: SCRIPTS_TABLE_ID, - catalog_name: DEFAULT_CATALOG_NAME.to_string(), + catalog_name: catalog.to_string(), + // TODO(dennis): put the scripts table into `system` schema? + // We always put the scripts table into `public` schema right now. schema_name: DEFAULT_SCHEMA_NAME.to_string(), table_name: SCRIPTS_TABLE_NAME.to_string(), desc: Some("GreptimeDB scripts table for Python".to_string()), schema: build_scripts_schema(), region_numbers: vec![0], - // 'schema' and 'name' are primary keys - primary_key_indices: vec![0, 1], + primary_key_indices: get_primary_key_indices(), create_if_not_exists: true, table_options: TableOptions::default(), engine: default_engine().to_string(), @@ -73,7 +89,7 @@ impl ScriptManager { let hook: OpenSystemTableHook = Box::new(move |table: TableRef| { let query_engine = query_engine.clone(); - async move { ScriptsTable::recompile_register_udf(table, query_engine.clone()).await } + async move { ScriptsTable::::recompile_register_udf(table, query_engine.clone()).await } .boxed() }); @@ -112,19 +128,48 @@ impl ScriptManager { .context(CompilePythonSnafu { name }) } + /// Get the scripts table in the catalog + pub fn get_scripts_table(&self, catalog: &str) -> Option> { + self.tables.read().unwrap().get(catalog).cloned() + } + + /// Insert a scripts table. + pub fn insert_scripts_table(&self, catalog: &str, table: TableRef) { + let mut tables = self.tables.write().unwrap(); + + if tables.get(catalog).is_some() { + return; + } + + tables.insert( + catalog.to_string(), + Arc::new(ScriptsTable::new( + table, + self.grpc_handler.load().as_ref().clone(), + self.query_engine.clone(), + )), + ); + } + pub async fn insert_and_compile( &self, + catalog: &str, schema: &str, name: &str, script: &str, ) -> Result> { let compiled_script = self.compile(name, script).await?; - self.table.insert(schema, name, script).await?; + self.get_scripts_table(catalog) + .context(ScriptsTableNotFoundSnafu)? + .insert(schema, name, script) + .await?; + Ok(compiled_script) } pub async fn execute( &self, + catalog: &str, schema: &str, name: &str, params: HashMap, @@ -135,7 +180,8 @@ impl ScriptManager { if s.is_some() { s } else { - self.try_find_script_and_compile(schema, name).await? + self.try_find_script_and_compile(catalog, schema, name) + .await? } }; @@ -149,10 +195,15 @@ impl ScriptManager { async fn try_find_script_and_compile( &self, + catalog: &str, schema: &str, name: &str, ) -> Result>> { - let script = self.table.find_script_by_name(schema, name).await?; + let script = self + .get_scripts_table(catalog) + .context(ScriptsTableNotFoundSnafu)? + .find_script_by_name(schema, name) + .await?; Ok(Some(self.compile(name, &script).await?)) } @@ -160,50 +211,67 @@ impl ScriptManager { #[cfg(test)] mod tests { - use catalog::memory::MemoryCatalogManager; - use query::QueryEngineFactory; - use super::*; + use crate::test::setup_scripts_manager; - #[ignore = "script engine is temporary disabled"] #[tokio::test] async fn test_insert_find_compile_script() { common_telemetry::init_default_ut_logging(); - let catalog_manager = MemoryCatalogManager::new(); - - let factory = QueryEngineFactory::new(catalog_manager.clone(), None, false); - let query_engine = factory.query_engine(); - let mgr = ScriptManager::new(catalog_manager.clone(), query_engine) - .await - .unwrap(); + let catalog = "greptime"; let schema = "schema"; let name = "test"; - mgr.table - .insert( - schema, - name, - r#" -@copr(sql='select number from numbers limit 10', args=['number'], returns=['n']) -def test(n): - return n + 1; -"#, - ) - .await - .unwrap(); + let script = r#" +@copr(returns=['n']) +def test() -> vector[str]: + return 'hello'; +"#; + + let mgr = setup_scripts_manager(catalog, schema, name, script).await; { let cached = mgr.compiled.read().unwrap(); assert!(cached.get(name).is_none()); } + mgr.insert_and_compile(catalog, schema, name, script) + .await + .unwrap(); + + { + let cached = mgr.compiled.read().unwrap(); + assert!(cached.get(name).is_some()); + } + // try to find and compile - let script = mgr.try_find_script_and_compile(schema, name).await.unwrap(); + let script = mgr + .try_find_script_and_compile(catalog, schema, name) + .await + .unwrap(); let _ = script.unwrap(); { let cached = mgr.compiled.read().unwrap(); let _ = cached.get(name).unwrap(); } + + // execute script + let output = mgr + .execute(catalog, schema, name, HashMap::new()) + .await + .unwrap(); + + match output { + Output::RecordBatches(batches) => { + let expected = "\ ++-------+ +| n | ++-------+ +| hello | ++-------+"; + assert_eq!(expected, batches.pretty_print().unwrap()); + } + _ => unreachable!(), + } } } diff --git a/src/script/src/python/engine.rs b/src/script/src/python/engine.rs index a2879ab001..5ab4f06298 100644 --- a/src/script/src/python/engine.rs +++ b/src/script/src/python/engine.rs @@ -75,6 +75,7 @@ impl PyUDF { fn register_as_udf(zelf: Arc) { FUNCTION_REGISTRY.register(zelf) } + fn register_to_query_engine(zelf: Arc, engine: QueryEngineRef) { engine.register_function(zelf) } @@ -138,10 +139,12 @@ impl Function for PyUDF { } } } + + // The Volatility should be volatile, the return value from evaluation may be changed. if know_all_types { - Signature::variadic(arg_types, Volatility::Immutable) + Signature::variadic(arg_types, Volatility::Volatile) } else { - Signature::any(self.copr.arg_types.len(), Volatility::Immutable) + Signature::any(self.copr.arg_types.len(), Volatility::Volatile) } } diff --git a/src/script/src/table.rs b/src/script/src/table.rs index 66910eb4cf..6e52f7139d 100644 --- a/src/script/src/table.rs +++ b/src/script/src/table.rs @@ -13,49 +13,68 @@ // limitations under the License. //! Scripts table -use std::collections::HashMap; use std::sync::Arc; +use api::helper::ColumnDataTypeWrapper; +use api::v1::greptime_request::Request; +use api::v1::value::ValueData; +use api::v1::{ + ColumnDataType, ColumnSchema as PbColumnSchema, Row, RowInsertRequest, RowInsertRequests, Rows, + SemanticType, +}; use catalog::error::CompileScriptInternalSnafu; -use catalog::CatalogManagerRef; -use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; -use common_catalog::format_full_table_name; -use common_error::ext::BoxedError; +use common_error::ext::{BoxedError, ErrorExt}; use common_query::Output; use common_recordbatch::{util as record_util, RecordBatch, SendableRecordBatchStream}; use common_telemetry::logging; use common_time::util; use datafusion::datasource::DefaultTableSource; +use datafusion::logical_expr::{and, col, lit}; use datafusion_common::TableReference; use datafusion_expr::LogicalPlanBuilder; -use datatypes::prelude::{ConcreteDataType, ScalarVector}; +use datatypes::prelude::ScalarVector; use datatypes::schema::{ColumnSchema, RawSchema}; -use datatypes::vectors::{StringVector, TimestampMillisecondVector, Vector, VectorRef}; -use query::parser::QueryLanguageParser; +use datatypes::vectors::{StringVector, Vector}; use query::plan::LogicalPlan; use query::QueryEngineRef; -use session::context::QueryContextBuilder; +use servers::query_handler::grpc::GrpcQueryHandlerRef; +use session::context::{QueryContextBuilder, QueryContextRef}; use snafu::{ensure, OptionExt, ResultExt}; -use table::requests::InsertRequest; +use table::metadata::TableInfo; use table::table::adapter::DfTableProviderAdapter; use table::TableRef; use crate::error::{ BuildDfLogicalPlanSnafu, CastTypeSnafu, CollectRecordsSnafu, ExecuteInternalStatementSnafu, - FindColumnInScriptsTableSnafu, FindScriptSnafu, FindScriptsTableSnafu, InsertScriptSnafu, - Result, ScriptNotFoundSnafu, ScriptsTableNotFoundSnafu, + FindColumnInScriptsTableSnafu, InsertScriptSnafu, Result, ScriptNotFoundSnafu, }; use crate::python::PyScript; pub const SCRIPTS_TABLE_NAME: &str = "scripts"; -pub struct ScriptsTable { - catalog_manager: CatalogManagerRef, +pub type ScriptsTableRef = Arc>; + +/// The scripts table that keeps the script content etc. +pub struct ScriptsTable { + table: TableRef, + grpc_handler: GrpcQueryHandlerRef, query_engine: QueryEngineRef, - name: String, } -impl ScriptsTable { +impl ScriptsTable { + /// Create a new `[ScriptsTable]` based on the table. + pub fn new( + table: TableRef, + grpc_handler: GrpcQueryHandlerRef, + query_engine: QueryEngineRef, + ) -> Self { + Self { + table, + grpc_handler, + query_engine, + } + } + fn get_str_col_by_name<'a>(record: &'a RecordBatch, name: &str) -> Result<&'a StringVector> { let column = record .column_by_name(name) @@ -79,6 +98,8 @@ impl ScriptsTable { table: TableRef, query_engine: QueryEngineRef, ) -> catalog::error::Result<()> { + let table_info = table.table_info(); + let rbs = Self::table_full_scan(table, &query_engine) .await .map_err(BoxedError::new) @@ -108,6 +129,12 @@ impl ScriptsTable { script_list.extend(part_of_scripts_list); } + logging::info!( + "Found {} scripts in {}", + script_list.len(), + table_info.full_table_name() + ); + for (name, script) in script_list { match PyScript::from_script(&script, query_engine.clone()) { Ok(script) => { @@ -129,123 +156,86 @@ impl ScriptsTable { Ok(()) } - pub fn new_empty( - catalog_manager: CatalogManagerRef, - query_engine: QueryEngineRef, - ) -> Result { - Ok(Self { - catalog_manager, - query_engine, - name: format_full_table_name( - DEFAULT_CATALOG_NAME, - DEFAULT_SCHEMA_NAME, - SCRIPTS_TABLE_NAME, - ), - }) - } - - pub async fn new( - catalog_manager: CatalogManagerRef, - query_engine: QueryEngineRef, - ) -> Result { - Ok(Self { - catalog_manager, - query_engine, - name: format_full_table_name( - DEFAULT_CATALOG_NAME, - DEFAULT_SCHEMA_NAME, - SCRIPTS_TABLE_NAME, - ), - }) - } - pub async fn insert(&self, schema: &str, name: &str, script: &str) -> Result<()> { let now = util::current_time_millis(); - let columns_values: HashMap = HashMap::from([ - ( - "schema".to_string(), - Arc::new(StringVector::from(vec![schema])) as VectorRef, - ), - ("name".to_string(), Arc::new(StringVector::from(vec![name]))), - ( - "script".to_string(), - Arc::new(StringVector::from(vec![script])) as VectorRef, - ), - ( - "engine".to_string(), - // TODO(dennis): we only supports python right now. - Arc::new(StringVector::from(vec!["python"])) as VectorRef, - ), - ( - "timestamp".to_string(), - // Timestamp in key part is intentionally left to 0 - Arc::new(TimestampMillisecondVector::from_slice([0])) as VectorRef, - ), - ( - "gmt_created".to_string(), - Arc::new(TimestampMillisecondVector::from_slice([now])) as VectorRef, - ), - ( - "gmt_modified".to_string(), - Arc::new(TimestampMillisecondVector::from_slice([now])) as VectorRef, - ), - ]); - let table = self - .catalog_manager - .table( - DEFAULT_CATALOG_NAME, - DEFAULT_SCHEMA_NAME, - SCRIPTS_TABLE_NAME, - ) - .await - .context(FindScriptsTableSnafu)? - .context(ScriptsTableNotFoundSnafu)?; - let _ = table - .insert(InsertRequest { - catalog_name: DEFAULT_CATALOG_NAME.to_string(), - schema_name: DEFAULT_SCHEMA_NAME.to_string(), - table_name: SCRIPTS_TABLE_NAME.to_string(), - columns_values, - region_number: 0, - }) + let table_info = self.table.table_info(); + + let insert = RowInsertRequest { + table_name: SCRIPTS_TABLE_NAME.to_string(), + rows: Some(Rows { + schema: build_insert_column_schemas(), + rows: vec![Row { + values: vec![ + ValueData::StringValue(schema.to_string()).into(), + ValueData::StringValue(name.to_string()).into(), + // TODO(dennis): we only supports python right now. + ValueData::StringValue("python".to_string()).into(), + ValueData::StringValue(script.to_string()).into(), + // Timestamp in key part is intentionally left to 0 + ValueData::TimestampMillisecondValue(0).into(), + ValueData::TimestampMillisecondValue(now).into(), + ], + }], + }), + }; + + let requests = RowInsertRequests { + inserts: vec![insert], + }; + + let output = self + .grpc_handler + .do_query(Request::RowInserts(requests), query_ctx(&table_info)) .await + .map_err(BoxedError::new) .context(InsertScriptSnafu { name })?; - logging::info!("Inserted script: name={} into scripts table.", name); + logging::info!( + "Inserted script: {} into scripts table: {}, output: {:?}.", + name, + table_info.full_table_name(), + output + ); Ok(()) } pub async fn find_script_by_name(&self, schema: &str, name: &str) -> Result { - // FIXME(dennis): SQL injection - // TODO(dennis): we use sql to find the script, the better way is use a function - // such as `find_record_by_primary_key` in table_engine. - let sql = format!( - "select script from {} where schema='{}' and name='{}'", - self.name(), - schema, - name + let table_info = self.table.table_info(); + + let table_name = TableReference::full( + table_info.catalog_name.clone(), + table_info.schema_name.clone(), + table_info.name.clone(), ); - let stmt = QueryLanguageParser::parse_sql(&sql).unwrap(); - let ctx = QueryContextBuilder::default().build(); - let plan = self - .query_engine - .planner() - .plan(stmt, ctx.clone()) - .await - .unwrap(); + let table_provider = Arc::new(DfTableProviderAdapter::new(self.table.clone())); + let table_source = Arc::new(DefaultTableSource::new(table_provider)); - let stream = match self + let plan = LogicalPlanBuilder::scan(table_name, table_source, None) + .context(BuildDfLogicalPlanSnafu)? + .filter(and( + col("schema").eq(lit(schema)), + col("name").eq(lit(name)), + )) + .context(BuildDfLogicalPlanSnafu)? + .project(vec![col("script")]) + .context(BuildDfLogicalPlanSnafu)? + .build() + .context(BuildDfLogicalPlanSnafu)?; + + let output = self .query_engine - .execute(plan, ctx) + .execute(LogicalPlan::DfPlan(plan), query_ctx(&table_info)) .await - .context(FindScriptSnafu { name })? - { + .context(ExecuteInternalStatementSnafu)?; + let stream = match output { Output::Stream(stream) => stream, + Output::RecordBatches(record_batches) => record_batches.as_stream(), _ => unreachable!(), }; + let records = record_util::collect(stream) .await .context(CollectRecordsSnafu)?; @@ -267,12 +257,9 @@ impl ScriptsTable { })?; assert_eq!(script_column.len(), 1); - Ok(script_column.get_data(0).unwrap().to_string()) - } - #[inline] - pub fn name(&self) -> &str { - &self.name + // Safety: asserted above + Ok(script_column.get_data(0).unwrap().to_string()) } async fn table_full_scan( @@ -295,10 +282,7 @@ impl ScriptsTable { .context(BuildDfLogicalPlanSnafu)?; let output = query_engine - .execute( - LogicalPlan::DfPlan(plan), - QueryContextBuilder::default().build(), - ) + .execute(LogicalPlan::DfPlan(plan), query_ctx(&table_info)) .await .context(ExecuteInternalStatementSnafu)?; let stream = match output { @@ -310,46 +294,80 @@ impl ScriptsTable { } } +/// Build the inserted column schemas +fn build_insert_column_schemas() -> Vec { + vec![ + // The schema that script belongs to. + PbColumnSchema { + column_name: "schema".to_string(), + datatype: ColumnDataType::String.into(), + semantic_type: SemanticType::Tag.into(), + }, + PbColumnSchema { + column_name: "name".to_string(), + datatype: ColumnDataType::String.into(), + semantic_type: SemanticType::Tag.into(), + }, + PbColumnSchema { + column_name: "engine".to_string(), + datatype: ColumnDataType::String.into(), + semantic_type: SemanticType::Tag.into(), + }, + PbColumnSchema { + column_name: "script".to_string(), + datatype: ColumnDataType::String.into(), + semantic_type: SemanticType::Field.into(), + }, + PbColumnSchema { + column_name: "greptime_timestamp".to_string(), + datatype: ColumnDataType::TimestampMillisecond.into(), + semantic_type: SemanticType::Timestamp.into(), + }, + PbColumnSchema { + column_name: "gmt_modified".to_string(), + datatype: ColumnDataType::TimestampMillisecond.into(), + semantic_type: SemanticType::Field.into(), + }, + ] +} + +fn query_ctx(table_info: &TableInfo) -> QueryContextRef { + QueryContextBuilder::default() + .current_catalog(table_info.catalog_name.to_string()) + .current_schema(table_info.schema_name.to_string()) + .build() +} + +/// Returns the scripts schema's primary key indices +pub fn get_primary_key_indices() -> Vec { + let mut indices = vec![]; + for (index, c) in build_insert_column_schemas().into_iter().enumerate() { + if c.semantic_type == (SemanticType::Tag as i32) { + indices.push(index); + } + } + + indices +} + /// Build scripts table pub fn build_scripts_schema() -> RawSchema { - let cols = vec![ - ColumnSchema::new( - "schema".to_string(), - ConcreteDataType::string_datatype(), - false, - ), - ColumnSchema::new( - "name".to_string(), - ConcreteDataType::string_datatype(), - false, - ), - ColumnSchema::new( - "script".to_string(), - ConcreteDataType::string_datatype(), - false, - ), - ColumnSchema::new( - "engine".to_string(), - ConcreteDataType::string_datatype(), - false, - ), - ColumnSchema::new( - "timestamp".to_string(), - ConcreteDataType::timestamp_millisecond_datatype(), - false, - ) - .with_time_index(true), - ColumnSchema::new( - "gmt_created".to_string(), - ConcreteDataType::timestamp_millisecond_datatype(), - false, - ), - ColumnSchema::new( - "gmt_modified".to_string(), - ConcreteDataType::timestamp_millisecond_datatype(), - false, - ), - ]; + let cols = build_insert_column_schemas() + .into_iter() + .map(|c| { + let cs = ColumnSchema::new( + c.column_name, + // Safety: the type always exists + ColumnDataTypeWrapper::try_new(c.datatype).unwrap().into(), + false, + ); + if c.semantic_type == SemanticType::Timestamp as i32 { + cs.with_time_index(true) + } else { + cs + } + }) + .collect(); RawSchema::new(cols) } diff --git a/src/script/src/test.rs b/src/script/src/test.rs new file mode 100644 index 0000000000..c937c4afd2 --- /dev/null +++ b/src/script/src/test.rs @@ -0,0 +1,78 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use api::v1::greptime_request::Request; +use async_trait::async_trait; +use catalog::memory::MemoryCatalogManager; +use common_query::Output; +use common_recordbatch::RecordBatch; +use datatypes::prelude::ConcreteDataType; +use datatypes::schema::{ColumnSchema, Schema}; +use datatypes::vectors::{StringVector, VectorRef}; +use query::QueryEngineFactory; +use servers::query_handler::grpc::GrpcQueryHandler; +use session::context::QueryContextRef; +use table::test_util::MemTable; + +use crate::error::{Error, Result}; +use crate::manager::ScriptManager; + +/// Setup the scripts table and create a script manager. +pub async fn setup_scripts_manager( + catalog: &str, + schema: &str, + name: &str, + script: &str, +) -> ScriptManager { + let column_schemas = vec![ + ColumnSchema::new("script", ConcreteDataType::string_datatype(), false), + ColumnSchema::new("schema", ConcreteDataType::string_datatype(), false), + ColumnSchema::new("name", ConcreteDataType::string_datatype(), false), + ]; + + let columns: Vec = vec![ + Arc::new(StringVector::from(vec![script])), + Arc::new(StringVector::from(vec![schema])), + Arc::new(StringVector::from(vec![name])), + ]; + + let schema = Arc::new(Schema::new(column_schemas)); + let recordbatch = RecordBatch::new(schema, columns).unwrap(); + + let table = MemTable::table("scripts", recordbatch); + + let catalog_manager = MemoryCatalogManager::new_with_table(table.clone()); + + let factory = QueryEngineFactory::new(catalog_manager.clone(), None, false); + let query_engine = factory.query_engine(); + let mgr = ScriptManager::new(Arc::new(MockGrpcQueryHandler {}) as _, query_engine) + .await + .unwrap(); + mgr.insert_scripts_table(catalog, table); + + mgr +} + +struct MockGrpcQueryHandler {} + +#[async_trait] +impl GrpcQueryHandler for MockGrpcQueryHandler { + type Error = Error; + + async fn do_query(&self, _query: Request, _ctx: QueryContextRef) -> Result { + Ok(Output::AffectedRows(1)) + } +} diff --git a/src/servers/src/http/script.rs b/src/servers/src/http/script.rs index 7cf59862ae..ae60ea98c3 100644 --- a/src/servers/src/http/script.rs +++ b/src/servers/src/http/script.rs @@ -15,9 +15,11 @@ use std::collections::HashMap; use std::time::Instant; use axum::extract::{Json, Query, RawBody, State}; +use common_catalog::consts::DEFAULT_CATALOG_NAME; use common_error::ext::ErrorExt; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; +use session::context::QueryContext; use crate::http::{ApiState, JsonResponse}; @@ -51,6 +53,9 @@ pub async fn scripts( RawBody(body): RawBody, ) -> Json { if let Some(script_handler) = &state.script_handler { + let catalog = params + .catalog + .unwrap_or_else(|| DEFAULT_CATALOG_NAME.to_string()); let schema = params.db.as_ref(); if schema.is_none() || schema.unwrap().is_empty() { @@ -67,8 +72,10 @@ pub async fn scripts( let script = unwrap_or_json_err!(String::from_utf8(bytes.to_vec())); + // Safety: schema and name are already checked above. + let query_ctx = QueryContext::with(&catalog, schema.unwrap()); let body = match script_handler - .insert_script(schema.unwrap(), name.unwrap(), &script) + .insert_script(query_ctx, name.unwrap(), &script) .await { Ok(()) => JsonResponse::with_output(None), @@ -83,6 +90,7 @@ pub async fn scripts( #[derive(Debug, Serialize, Deserialize, JsonSchema, Default)] pub struct ScriptQuery { + pub catalog: Option, pub db: Option, pub name: Option, #[serde(flatten)] @@ -96,6 +104,9 @@ pub async fn run_script( Query(params): Query, ) -> Json { if let Some(script_handler) = &state.script_handler { + let catalog = params + .catalog + .unwrap_or_else(|| DEFAULT_CATALOG_NAME.to_string()); let start = Instant::now(); let schema = params.db.as_ref(); @@ -109,10 +120,10 @@ pub async fn run_script( json_err!("invalid name"); } - // TODO(sunng87): query_context and db name resolution - + // Safety: schema and name are already checked above. + let query_ctx = QueryContext::with(&catalog, schema.unwrap()); let output = script_handler - .execute_script(schema.unwrap(), name.unwrap(), params.params) + .execute_script(query_ctx, name.unwrap(), params.params) .await; let resp = JsonResponse::from_output(vec![output]).await; diff --git a/src/servers/src/query_handler.rs b/src/servers/src/query_handler.rs index 796303a888..ef8f74575e 100644 --- a/src/servers/src/query_handler.rs +++ b/src/servers/src/query_handler.rs @@ -49,10 +49,15 @@ pub type ScriptHandlerRef = Arc; #[async_trait] pub trait ScriptHandler { - async fn insert_script(&self, schema: &str, name: &str, script: &str) -> Result<()>; + async fn insert_script( + &self, + query_ctx: QueryContextRef, + name: &str, + script: &str, + ) -> Result<()>; async fn execute_script( &self, - schema: &str, + query_ctx: QueryContextRef, name: &str, params: HashMap, ) -> Result; diff --git a/src/servers/tests/mod.rs b/src/servers/tests/mod.rs index 1755756a32..4ef156cb04 100644 --- a/src/servers/tests/mod.rs +++ b/src/servers/tests/mod.rs @@ -116,7 +116,15 @@ impl SqlQueryHandler for DummyInstance { #[async_trait] impl ScriptHandler for DummyInstance { - async fn insert_script(&self, schema: &str, name: &str, script: &str) -> Result<()> { + async fn insert_script( + &self, + query_ctx: QueryContextRef, + name: &str, + script: &str, + ) -> Result<()> { + let catalog = query_ctx.current_catalog(); + let schema = query_ctx.current_schema(); + let script = self .py_engine .compile(script, CompileContext::default()) @@ -127,18 +135,20 @@ impl ScriptHandler for DummyInstance { .scripts .write() .unwrap() - .insert(format!("{schema}_{name}"), Arc::new(script)); + .insert(format!("{catalog}_{schema}_{name}"), Arc::new(script)); Ok(()) } async fn execute_script( &self, - schema: &str, + query_ctx: QueryContextRef, name: &str, params: HashMap, ) -> Result { - let key = format!("{schema}_{name}"); + let catalog = query_ctx.current_catalog(); + let schema = query_ctx.current_schema(); + let key = format!("{catalog}_{schema}_{name}"); let py_script = self.scripts.read().unwrap().get(&key).unwrap().clone(); diff --git a/src/servers/tests/py_script/mod.rs b/src/servers/tests/py_script/mod.rs index 49dcdd7aa4..bde643cb84 100644 --- a/src/servers/tests/py_script/mod.rs +++ b/src/servers/tests/py_script/mod.rs @@ -12,30 +12,79 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::collections::HashMap; +use std::sync::Arc; + +use common_query::Output; +use common_recordbatch::RecordBatch; +use datatypes::prelude::ConcreteDataType; +use datatypes::schema::{ColumnSchema, Schema}; +use datatypes::vectors::{StringVector, VectorRef}; use servers::error::Result; use servers::query_handler::sql::SqlQueryHandler; use servers::query_handler::ScriptHandler; -use session::context::QueryContext; +use session::context::QueryContextBuilder; use table::test_util::MemTable; use crate::create_testing_instance; #[tokio::test] async fn test_insert_py_udf_and_query() -> Result<()> { - let query_ctx = QueryContext::arc(); - let table = MemTable::default_numbers_table(); + let catalog = "greptime"; + let schema = "test"; + let name = "hello"; + let script = r#" +@copr(returns=['n']) +def hello() -> vector[str]: + return 'hello'; +"#; + + let column_schemas = vec![ + ColumnSchema::new("script", ConcreteDataType::string_datatype(), false), + ColumnSchema::new("schema", ConcreteDataType::string_datatype(), false), + ColumnSchema::new("name", ConcreteDataType::string_datatype(), false), + ]; + + let columns: Vec = vec![ + Arc::new(StringVector::from(vec![script])), + Arc::new(StringVector::from(vec![schema])), + Arc::new(StringVector::from(vec![name])), + ]; + + let raw_schema = Arc::new(Schema::new(column_schemas)); + let recordbatch = RecordBatch::new(raw_schema, columns).unwrap(); + + let table = MemTable::table("scripts", recordbatch); + + let query_ctx = QueryContextBuilder::default() + .current_catalog(catalog.to_string()) + .current_schema(schema.to_string()) + .build(); let instance = create_testing_instance(table); - let src = r#" -@coprocessor(args=["uint32s"], returns = ["ret"]) -def double_that(col) -> vector[u32]: - return col*2 - "#; instance - .insert_script("schema_test", "double_that", src) + .insert_script(query_ctx.clone(), name, script) .await?; + + let output = instance + .execute_script(query_ctx.clone(), name, HashMap::new()) + .await?; + + match output { + Output::RecordBatches(batches) => { + let expected = "\ ++-------+ +| n | ++-------+ +| hello | ++-------+"; + assert_eq!(expected, batches.pretty_print().unwrap()); + } + _ => unreachable!(), + } + let res = instance - .do_query("select double_that(uint32s) from numbers", query_ctx) + .do_query("select hello()", query_ctx) .await .remove(0) .unwrap(); @@ -46,12 +95,16 @@ def double_that(col) -> vector[u32]: } common_query::Output::Stream(s) => { let batches = common_recordbatch::util::collect_batches(s).await.unwrap(); - assert_eq!(batches.iter().count(), 1); - let first = batches.iter().next().unwrap(); - let col = first.column(0); - let val = col.get(1); - assert_eq!(val, datatypes::value::Value::UInt32(2)); + let expected = "\ ++---------+ +| hello() | ++---------+ +| hello | ++---------+"; + + assert_eq!(expected, batches.pretty_print().unwrap()); } } + Ok(()) } diff --git a/src/table/src/metadata.rs b/src/table/src/metadata.rs index ad1c7eb357..0f8880a14c 100644 --- a/src/table/src/metadata.rs +++ b/src/table/src/metadata.rs @@ -479,6 +479,10 @@ impl TableInfo { .map(|id| RegionId::new(self.table_id(), *id)) .collect() } + /// Returns the full table name in the form of `{catalog}.{schema}.{table}`. + pub fn full_table_name(&self) -> String { + common_catalog::format_full_table_name(&self.catalog_name, &self.schema_name, &self.name) + } } impl TableInfoBuilder {