diff --git a/Cargo.lock b/Cargo.lock index c14b4fd7f9..9e8af77dc0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4876,7 +4876,7 @@ dependencies = [ [[package]] name = "greptime-proto" version = "0.1.0" -source = "git+https://github.com/GreptimeTeam/greptime-proto.git?rev=442348b2518c0bf187fb1ad011ba370c38b96cc4#442348b2518c0bf187fb1ad011ba370c38b96cc4" +source = "git+https://github.com/GreptimeTeam/greptime-proto.git?rev=454c52634c3bac27de10bf0d85d5533eed1cf03f#454c52634c3bac27de10bf0d85d5533eed1cf03f" dependencies = [ "prost 0.13.5", "serde", diff --git a/Cargo.toml b/Cargo.toml index b85418efee..856d552246 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -132,7 +132,7 @@ etcd-client = "0.14" fst = "0.4.7" futures = "0.3" futures-util = "0.3" -greptime-proto = { git = "https://github.com/GreptimeTeam/greptime-proto.git", rev = "442348b2518c0bf187fb1ad011ba370c38b96cc4" } +greptime-proto = { git = "https://github.com/GreptimeTeam/greptime-proto.git", rev = "454c52634c3bac27de10bf0d85d5533eed1cf03f" } hex = "0.4" http = "1" humantime = "2.1" diff --git a/licenserc.toml b/licenserc.toml index 4fc00026c3..4a727b6b1b 100644 --- a/licenserc.toml +++ b/licenserc.toml @@ -27,6 +27,8 @@ excludes = [ "src/servers/src/repeated_field.rs", "src/servers/src/http/test_helpers.rs", # enterprise + "src/common/meta/src/rpc/ddl/trigger.rs", + "src/operator/src/expr_helper/trigger.rs", "src/sql/src/statements/create/trigger.rs", "src/sql/src/statements/show/trigger.rs", "src/sql/src/parsers/create_parser/trigger.rs", diff --git a/src/cmd/Cargo.toml b/src/cmd/Cargo.toml index 99611986c3..6c679fad87 100644 --- a/src/cmd/Cargo.toml +++ b/src/cmd/Cargo.toml @@ -16,6 +16,7 @@ default = [ "meta-srv/pg_kvbackend", "meta-srv/mysql_kvbackend", ] +enterprise = ["common-meta/enterprise", "frontend/enterprise", "meta-srv/enterprise"] tokio-console = ["common-telemetry/tokio-console"] [lints] diff --git a/src/cmd/src/standalone.rs b/src/cmd/src/standalone.rs index 80379c3ba7..3264b87449 100644 --- a/src/cmd/src/standalone.rs +++ b/src/cmd/src/standalone.rs @@ -35,6 +35,8 @@ use common_meta::ddl::flow_meta::{FlowMetadataAllocator, FlowMetadataAllocatorRe use common_meta::ddl::table_meta::{TableMetadataAllocator, TableMetadataAllocatorRef}; use common_meta::ddl::{DdlContext, NoopRegionFailureDetectorControl, ProcedureExecutorRef}; use common_meta::ddl_manager::DdlManager; +#[cfg(feature = "enterprise")] +use common_meta::ddl_manager::TriggerDdlManagerRef; use common_meta::key::flow::flow_state::FlowStat; use common_meta::key::flow::{FlowMetadataManager, FlowMetadataManagerRef}; use common_meta::key::{TableMetadataManager, TableMetadataManagerRef}; @@ -579,6 +581,8 @@ impl StartCommand { flow_id_sequence, )); + #[cfg(feature = "enterprise")] + let trigger_ddl_manager: Option = plugins.get(); let ddl_task_executor = Self::create_ddl_task_executor( procedure_manager.clone(), node_manager.clone(), @@ -587,6 +591,8 @@ impl StartCommand { table_meta_allocator, flow_metadata_manager, flow_meta_allocator, + #[cfg(feature = "enterprise")] + trigger_ddl_manager, ) .await?; @@ -651,6 +657,7 @@ impl StartCommand { }) } + #[allow(clippy::too_many_arguments)] pub async fn create_ddl_task_executor( procedure_manager: ProcedureManagerRef, node_manager: NodeManagerRef, @@ -659,6 +666,7 @@ impl StartCommand { table_metadata_allocator: TableMetadataAllocatorRef, flow_metadata_manager: FlowMetadataManagerRef, flow_metadata_allocator: FlowMetadataAllocatorRef, + #[cfg(feature = "enterprise")] trigger_ddl_manager: Option, ) -> Result { let procedure_executor: ProcedureExecutorRef = Arc::new( DdlManager::try_new( @@ -675,6 +683,8 @@ impl StartCommand { }, procedure_manager, true, + #[cfg(feature = "enterprise")] + trigger_ddl_manager, ) .context(error::InitDdlManagerSnafu)?, ); diff --git a/src/common/meta/Cargo.toml b/src/common/meta/Cargo.toml index 9c1351fa10..3cd26d7abb 100644 --- a/src/common/meta/Cargo.toml +++ b/src/common/meta/Cargo.toml @@ -8,6 +8,7 @@ license.workspace = true testing = [] pg_kvbackend = ["dep:tokio-postgres", "dep:backon", "dep:deadpool-postgres", "dep:deadpool"] mysql_kvbackend = ["dep:sqlx", "dep:backon"] +enterprise = [] [lints] workspace = true diff --git a/src/common/meta/src/ddl_manager.rs b/src/common/meta/src/ddl_manager.rs index c943c70bc9..abd739ae4a 100644 --- a/src/common/meta/src/ddl_manager.rs +++ b/src/common/meta/src/ddl_manager.rs @@ -47,6 +47,10 @@ use crate::error::{ use crate::key::table_info::TableInfoValue; use crate::key::table_name::TableNameKey; use crate::key::{DeserializedValueWithBytes, TableMetadataManagerRef}; +#[cfg(feature = "enterprise")] +use crate::rpc::ddl::trigger::CreateTriggerTask; +#[cfg(feature = "enterprise")] +use crate::rpc::ddl::DdlTask::CreateTrigger; use crate::rpc::ddl::DdlTask::{ AlterDatabase, AlterLogicalTables, AlterTable, CreateDatabase, CreateFlow, CreateLogicalTables, CreateTable, CreateView, DropDatabase, DropFlow, DropLogicalTables, DropTable, DropView, @@ -70,8 +74,29 @@ pub type BoxedProcedureLoaderFactory = dyn Fn(DdlContext) -> BoxedProcedureLoade pub struct DdlManager { ddl_context: DdlContext, procedure_manager: ProcedureManagerRef, + #[cfg(feature = "enterprise")] + trigger_ddl_manager: Option, } +/// This trait is responsible for handling DDL tasks about triggers. e.g., +/// create trigger, drop trigger, etc. +#[cfg(feature = "enterprise")] +#[async_trait::async_trait] +pub trait TriggerDdlManager: Send + Sync { + async fn create_trigger( + &self, + create_trigger_task: CreateTriggerTask, + procedure_manager: ProcedureManagerRef, + ddl_context: DdlContext, + query_context: QueryContext, + ) -> Result; + + fn as_any(&self) -> &dyn std::any::Any; +} + +#[cfg(feature = "enterprise")] +pub type TriggerDdlManagerRef = Arc; + macro_rules! procedure_loader_entry { ($procedure:ident) => { ( @@ -100,10 +125,13 @@ impl DdlManager { ddl_context: DdlContext, procedure_manager: ProcedureManagerRef, register_loaders: bool, + #[cfg(feature = "enterprise")] trigger_ddl_manager: Option, ) -> Result { let manager = Self { ddl_context, procedure_manager, + #[cfg(feature = "enterprise")] + trigger_ddl_manager, }; if register_loaders { manager.register_loaders()?; @@ -669,6 +697,28 @@ async fn handle_create_flow_task( }) } +#[cfg(feature = "enterprise")] +async fn handle_create_trigger_task( + ddl_manager: &DdlManager, + create_trigger_task: CreateTriggerTask, + query_context: QueryContext, +) -> Result { + let Some(m) = ddl_manager.trigger_ddl_manager.as_ref() else { + return UnsupportedSnafu { + operation: "create trigger", + } + .fail(); + }; + + m.create_trigger( + create_trigger_task, + ddl_manager.procedure_manager.clone(), + ddl_manager.ddl_context.clone(), + query_context, + ) + .await +} + async fn handle_alter_logical_table_tasks( ddl_manager: &DdlManager, alter_table_tasks: Vec, @@ -777,6 +827,15 @@ impl ProcedureExecutor for DdlManager { handle_create_flow_task(self, create_flow_task, request.query_context.into()) .await } + #[cfg(feature = "enterprise")] + CreateTrigger(create_trigger_task) => { + handle_create_trigger_task( + self, + create_trigger_task, + request.query_context.into(), + ) + .await + } DropFlow(drop_flow_task) => handle_drop_flow_task(self, drop_flow_task).await, CreateView(create_view_task) => { handle_create_view_task(self, create_view_task).await @@ -905,6 +964,8 @@ mod tests { }, procedure_manager.clone(), true, + #[cfg(feature = "enterprise")] + None, ); let expected_loaders = vec![ diff --git a/src/common/meta/src/rpc/ddl.rs b/src/common/meta/src/rpc/ddl.rs index 2797c6bee0..99eb24cebb 100644 --- a/src/common/meta/src/rpc/ddl.rs +++ b/src/common/meta/src/rpc/ddl.rs @@ -12,6 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +#[cfg(feature = "enterprise")] +pub mod trigger; + use std::collections::{HashMap, HashSet}; use std::result; @@ -68,6 +71,8 @@ pub enum DdlTask { DropFlow(DropFlowTask), CreateView(CreateViewTask), DropView(DropViewTask), + #[cfg(feature = "enterprise")] + CreateTrigger(trigger::CreateTriggerTask), } impl DdlTask { @@ -242,6 +247,18 @@ impl TryFrom for DdlTask { Task::DropFlowTask(drop_flow) => Ok(DdlTask::DropFlow(drop_flow.try_into()?)), Task::CreateViewTask(create_view) => Ok(DdlTask::CreateView(create_view.try_into()?)), Task::DropViewTask(drop_view) => Ok(DdlTask::DropView(drop_view.try_into()?)), + Task::CreateTriggerTask(create_trigger) => { + #[cfg(feature = "enterprise")] + return Ok(DdlTask::CreateTrigger(create_trigger.try_into()?)); + #[cfg(not(feature = "enterprise"))] + { + let _ = create_trigger; + crate::error::UnsupportedSnafu { + operation: "create trigger", + } + .fail() + } + } } } } @@ -292,6 +309,8 @@ impl TryFrom for PbDdlTaskRequest { DdlTask::DropFlow(task) => Task::DropFlowTask(task.into()), DdlTask::CreateView(task) => Task::CreateViewTask(task.try_into()?), DdlTask::DropView(task) => Task::DropViewTask(task.into()), + #[cfg(feature = "enterprise")] + DdlTask::CreateTrigger(task) => Task::CreateTriggerTask(task.into()), }; Ok(Self { diff --git a/src/common/meta/src/rpc/ddl/trigger.rs b/src/common/meta/src/rpc/ddl/trigger.rs new file mode 100644 index 0000000000..8d957efc2a --- /dev/null +++ b/src/common/meta/src/rpc/ddl/trigger.rs @@ -0,0 +1,276 @@ +use std::collections::HashMap; +use std::time::Duration; + +use api::v1::meta::CreateTriggerTask as PbCreateTriggerTask; +use api::v1::notify_channel::ChannelType as PbChannelType; +use api::v1::{ + CreateTriggerExpr, NotifyChannel as PbNotifyChannel, WebhookOptions as PbWebhookOptions, +}; +use serde::{Deserialize, Serialize}; +use snafu::OptionExt; + +use crate::error; +use crate::error::Result; +use crate::rpc::ddl::DdlTask; + +// Create trigger +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CreateTriggerTask { + pub catalog_name: String, + pub trigger_name: String, + pub if_not_exists: bool, + pub sql: String, + pub channels: Vec, + pub labels: HashMap, + pub annotations: HashMap, + pub interval: Duration, +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct NotifyChannel { + pub name: String, + pub channel_type: ChannelType, +} + +/// The available channel enum for sending trigger notifications. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum ChannelType { + Webhook(WebhookOptions), +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct WebhookOptions { + /// The URL of the AlertManager API endpoint. + /// + /// e.g., "http://localhost:9093". + pub url: String, + /// Configuration options for the AlertManager webhook. e.g., timeout, etc. + pub opts: HashMap, +} + +impl From for PbCreateTriggerTask { + fn from(task: CreateTriggerTask) -> Self { + let channels = task + .channels + .into_iter() + .map(PbNotifyChannel::from) + .collect(); + + let expr = CreateTriggerExpr { + catalog_name: task.catalog_name, + trigger_name: task.trigger_name, + create_if_not_exists: task.if_not_exists, + sql: task.sql, + channels, + labels: task.labels, + annotations: task.annotations, + interval: task.interval.as_secs(), + }; + + PbCreateTriggerTask { + create_trigger: Some(expr), + } + } +} + +impl TryFrom for CreateTriggerTask { + type Error = error::Error; + + fn try_from(task: PbCreateTriggerTask) -> Result { + let expr = task.create_trigger.context(error::InvalidProtoMsgSnafu { + err_msg: "expected create_trigger", + })?; + + let channels = expr + .channels + .into_iter() + .map(NotifyChannel::try_from) + .collect::>>()?; + + let task = CreateTriggerTask { + catalog_name: expr.catalog_name, + trigger_name: expr.trigger_name, + if_not_exists: expr.create_if_not_exists, + sql: expr.sql, + channels, + labels: expr.labels, + annotations: expr.annotations, + interval: Duration::from_secs(expr.interval), + }; + Ok(task) + } +} + +impl From for PbNotifyChannel { + fn from(channel: NotifyChannel) -> Self { + let NotifyChannel { name, channel_type } = channel; + + let channel_type = match channel_type { + ChannelType::Webhook(options) => PbChannelType::Webhook(PbWebhookOptions { + url: options.url, + opts: options.opts, + }), + }; + + PbNotifyChannel { + name, + channel_type: Some(channel_type), + } + } +} + +impl TryFrom for NotifyChannel { + type Error = error::Error; + + fn try_from(channel: PbNotifyChannel) -> Result { + let PbNotifyChannel { name, channel_type } = channel; + + let channel_type = channel_type.context(error::InvalidProtoMsgSnafu { + err_msg: "expected channel_type", + })?; + + let channel_type = match channel_type { + PbChannelType::Webhook(options) => ChannelType::Webhook(WebhookOptions { + url: options.url, + opts: options.opts, + }), + }; + Ok(NotifyChannel { name, channel_type }) + } +} + +impl DdlTask { + /// Creates a [`DdlTask`] to create a trigger. + pub fn new_create_trigger(expr: CreateTriggerTask) -> Self { + DdlTask::CreateTrigger(expr) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_convert_create_trigger_task() { + let original = CreateTriggerTask { + catalog_name: "test_catalog".to_string(), + trigger_name: "test_trigger".to_string(), + if_not_exists: true, + sql: "SELECT * FROM test".to_string(), + channels: vec![ + NotifyChannel { + name: "channel1".to_string(), + channel_type: ChannelType::Webhook(WebhookOptions { + url: "http://localhost:9093".to_string(), + opts: HashMap::from([("timeout".to_string(), "30s".to_string())]), + }), + }, + NotifyChannel { + name: "channel2".to_string(), + channel_type: ChannelType::Webhook(WebhookOptions { + url: "http://alertmanager:9093".to_string(), + opts: HashMap::new(), + }), + }, + ], + labels: vec![ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "value2".to_string()), + ] + .into_iter() + .collect(), + annotations: vec![ + ("summary".to_string(), "Test alert".to_string()), + ("description".to_string(), "This is a test".to_string()), + ] + .into_iter() + .collect(), + interval: Duration::from_secs(60), + }; + + let pb_task: PbCreateTriggerTask = original.clone().into(); + + let expr = pb_task.create_trigger.as_ref().unwrap(); + assert_eq!(expr.catalog_name, "test_catalog"); + assert_eq!(expr.trigger_name, "test_trigger"); + assert!(expr.create_if_not_exists); + assert_eq!(expr.sql, "SELECT * FROM test"); + assert_eq!(expr.channels.len(), 2); + assert_eq!(expr.labels.len(), 2); + assert_eq!(expr.labels.get("key1").unwrap(), "value1"); + assert_eq!(expr.labels.get("key2").unwrap(), "value2"); + assert_eq!(expr.annotations.len(), 2); + assert_eq!(expr.annotations.get("summary").unwrap(), "Test alert"); + assert_eq!( + expr.annotations.get("description").unwrap(), + "This is a test" + ); + assert_eq!(expr.interval, 60); + + let round_tripped = CreateTriggerTask::try_from(pb_task).unwrap(); + + assert_eq!(original.catalog_name, round_tripped.catalog_name); + assert_eq!(original.trigger_name, round_tripped.trigger_name); + assert_eq!(original.if_not_exists, round_tripped.if_not_exists); + assert_eq!(original.sql, round_tripped.sql); + assert_eq!(original.channels.len(), round_tripped.channels.len()); + assert_eq!(&original.channels[0], &round_tripped.channels[0]); + assert_eq!(&original.channels[1], &round_tripped.channels[1]); + assert_eq!(original.labels, round_tripped.labels); + assert_eq!(original.annotations, round_tripped.annotations); + assert_eq!(original.interval, round_tripped.interval); + + // Invalid, since create_trigger is None and it's required. + let invalid_task = PbCreateTriggerTask { + create_trigger: None, + }; + let result = CreateTriggerTask::try_from(invalid_task); + assert!(result.is_err()); + } + + #[test] + fn test_convert_notify_channel() { + let original = NotifyChannel { + name: "test_channel".to_string(), + channel_type: ChannelType::Webhook(WebhookOptions { + url: "http://localhost:9093".to_string(), + opts: HashMap::new(), + }), + }; + let pb_channel: PbNotifyChannel = original.clone().into(); + match pb_channel.channel_type.as_ref().unwrap() { + PbChannelType::Webhook(options) => { + assert_eq!(pb_channel.name, "test_channel"); + assert_eq!(options.url, "http://localhost:9093"); + assert!(options.opts.is_empty()); + } + } + let round_tripped = NotifyChannel::try_from(pb_channel).unwrap(); + assert_eq!(original, round_tripped); + + // Test with timeout is None. + let no_timeout = NotifyChannel { + name: "no_timeout".to_string(), + channel_type: ChannelType::Webhook(WebhookOptions { + url: "http://localhost:9093".to_string(), + opts: HashMap::new(), + }), + }; + let pb_no_timeout: PbNotifyChannel = no_timeout.clone().into(); + match pb_no_timeout.channel_type.as_ref().unwrap() { + PbChannelType::Webhook(options) => { + assert_eq!(options.url, "http://localhost:9093"); + } + } + let round_tripped_no_timeout = NotifyChannel::try_from(pb_no_timeout).unwrap(); + assert_eq!(no_timeout, round_tripped_no_timeout); + + // Invalid, since channel_type is None and it's required. + let invalid_channel = PbNotifyChannel { + name: "invalid".to_string(), + channel_type: None, + }; + let result = NotifyChannel::try_from(invalid_channel); + assert!(result.is_err()); + } +} diff --git a/src/frontend/Cargo.toml b/src/frontend/Cargo.toml index 9d54845f21..47cd0d4975 100644 --- a/src/frontend/Cargo.toml +++ b/src/frontend/Cargo.toml @@ -6,7 +6,7 @@ license.workspace = true [features] testing = [] -enterprise = ["operator/enterprise", "sql/enterprise"] +enterprise = ["common-meta/enterprise", "operator/enterprise", "sql/enterprise"] [lints] workspace = true diff --git a/src/meta-srv/Cargo.toml b/src/meta-srv/Cargo.toml index 6c61ffd546..c2b042059e 100644 --- a/src/meta-srv/Cargo.toml +++ b/src/meta-srv/Cargo.toml @@ -9,6 +9,7 @@ mock = [] pg_kvbackend = ["dep:tokio-postgres", "common-meta/pg_kvbackend", "dep:deadpool-postgres", "dep:deadpool"] mysql_kvbackend = ["dep:sqlx", "common-meta/mysql_kvbackend"] testing = ["common-wal/testing"] +enterprise = ["common-meta/enterprise"] [lints] workspace = true diff --git a/src/meta-srv/src/metasrv/builder.rs b/src/meta-srv/src/metasrv/builder.rs index fd7fc440b4..902e23fc17 100644 --- a/src/meta-srv/src/metasrv/builder.rs +++ b/src/meta-srv/src/metasrv/builder.rs @@ -280,7 +280,7 @@ impl MetasrvBuilder { ensure!( options.allow_region_failover_on_local_wal, error::UnexpectedSnafu { - violated: "Region failover is not supported in the local WAL implementation! + violated: "Region failover is not supported in the local WAL implementation! If you want to enable region failover for local WAL, please set `allow_region_failover_on_local_wal` to true.", } ); @@ -351,6 +351,11 @@ impl MetasrvBuilder { }; let leader_region_registry = Arc::new(LeaderRegionRegistry::default()); + + #[cfg(feature = "enterprise")] + let trigger_ddl_manager = plugins + .as_ref() + .and_then(|plugins| plugins.get::()); let ddl_manager = Arc::new( DdlManager::try_new( DdlContext { @@ -366,6 +371,8 @@ impl MetasrvBuilder { }, procedure_manager.clone(), true, + #[cfg(feature = "enterprise")] + trigger_ddl_manager, ) .context(error::InitDdlManagerSnafu)?, ); diff --git a/src/operator/Cargo.toml b/src/operator/Cargo.toml index 8ce93ff827..c0d7f5fd17 100644 --- a/src/operator/Cargo.toml +++ b/src/operator/Cargo.toml @@ -6,7 +6,7 @@ license.workspace = true [features] testing = [] -enterprise = ["sql/enterprise"] +enterprise = ["common-meta/enterprise", "sql/enterprise"] [lints] workspace = true diff --git a/src/operator/src/error.rs b/src/operator/src/error.rs index 36b8cda685..1242138539 100644 --- a/src/operator/src/error.rs +++ b/src/operator/src/error.rs @@ -703,6 +703,14 @@ pub enum Error { location: Location, }, + #[cfg(feature = "enterprise")] + #[snafu(display("Invalid trigger name: {name}"))] + InvalidTriggerName { + name: String, + #[snafu(implicit)] + location: Location, + }, + #[snafu(display("Empty {} expr", name))] EmptyDdlExpr { name: String, @@ -872,6 +880,8 @@ impl ErrorExt for Error { | Error::CursorNotFound { .. } | Error::CursorExists { .. } | Error::CreatePartitionRules { .. } => StatusCode::InvalidArguments, + #[cfg(feature = "enterprise")] + Error::InvalidTriggerName { .. } => StatusCode::InvalidArguments, Error::TableAlreadyExists { .. } | Error::ViewAlreadyExists { .. } => { StatusCode::TableAlreadyExists } diff --git a/src/operator/src/expr_helper.rs b/src/operator/src/expr_helper.rs index fe5689042f..eff0ec5555 100644 --- a/src/operator/src/expr_helper.rs +++ b/src/operator/src/expr_helper.rs @@ -12,6 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +#[cfg(feature = "enterprise")] +pub mod trigger; + use std::collections::{HashMap, HashSet}; use api::helper::ColumnDataTypeWrapper; @@ -55,6 +58,8 @@ use sql::statements::{ use sql::util::extract_tables_from_query; use table::requests::{TableOptions, FILE_TABLE_META_KEY}; use table::table_reference::TableReference; +#[cfg(feature = "enterprise")] +pub use trigger::to_create_trigger_task_expr; use crate::error::{ BuildCreateExprOnInsertionSnafu, ColumnDataTypeSnafu, ConvertColumnDefaultConstraintSnafu, diff --git a/src/operator/src/expr_helper/trigger.rs b/src/operator/src/expr_helper/trigger.rs new file mode 100644 index 0000000000..37ed64d6ee --- /dev/null +++ b/src/operator/src/expr_helper/trigger.rs @@ -0,0 +1,146 @@ +use api::v1::notify_channel::ChannelType as PbChannelType; +use api::v1::{ + CreateTriggerExpr as PbCreateTriggerExpr, NotifyChannel as PbNotifyChannel, + WebhookOptions as PbWebhookOptions, +}; +use session::context::QueryContextRef; +use snafu::ensure; +use sql::ast::ObjectName; +use sql::statements::create::trigger::{ChannelType, CreateTrigger}; + +use crate::error::Result; + +pub fn to_create_trigger_task_expr( + create_trigger: CreateTrigger, + query_ctx: &QueryContextRef, +) -> Result { + let CreateTrigger { + trigger_name, + if_not_exists, + query, + interval, + labels, + annotations, + channels, + } = create_trigger; + + let catalog_name = query_ctx.current_catalog().to_string(); + let trigger_name = sanitize_trigger_name(trigger_name)?; + + let channels = channels + .into_iter() + .map(|c| { + let name = c.name.value; + match c.channel_type { + ChannelType::Webhook(am) => PbNotifyChannel { + name, + channel_type: Some(PbChannelType::Webhook(PbWebhookOptions { + url: am.url.value, + opts: am.options.into_map(), + })), + }, + } + }) + .collect::>(); + + let sql = query.to_string(); + let labels = labels.into_map(); + let annotations = annotations.into_map(); + + Ok(PbCreateTriggerExpr { + catalog_name, + trigger_name, + create_if_not_exists: if_not_exists, + sql, + channels, + labels, + annotations, + interval, + }) +} + +fn sanitize_trigger_name(mut trigger_name: ObjectName) -> Result { + ensure!( + trigger_name.0.len() == 1, + crate::error::InvalidTriggerNameSnafu { + name: trigger_name.to_string(), + } + ); + // safety: we've checked trigger_name.0 has exactly one element. + Ok(trigger_name.0.swap_remove(0).value) +} + +#[cfg(test)] +mod tests { + use session::context::QueryContext; + use sql::dialect::GreptimeDbDialect; + use sql::parser::{ParseOptions, ParserContext}; + use sql::statements::statement::Statement; + + use super::*; + + #[test] + fn test_sanitize_trigger_name() { + let name = ObjectName(vec![sql::ast::Ident::new("my_trigger")]); + let sanitized = sanitize_trigger_name(name).unwrap(); + assert_eq!(sanitized, "my_trigger"); + + let name = ObjectName(vec![sql::ast::Ident::with_quote('`', "my_trigger")]); + let sanitized = sanitize_trigger_name(name).unwrap(); + assert_eq!(sanitized, "my_trigger"); + + let name = ObjectName(vec![sql::ast::Ident::with_quote('\'', "trigger")]); + let sanitized = sanitize_trigger_name(name).unwrap(); + assert_eq!(sanitized, "trigger"); + } + + #[test] + fn test_to_create_trigger_task_expr() { + let sql = r#"CREATE TRIGGER IF NOT EXISTS cpu_monitor +ON (SELECT host AS host_label, cpu, memory FROM machine_monitor WHERE cpu > 2) EVERY '5 minute'::INTERVAL +LABELS (label_name=label_val) +ANNOTATIONS (annotation_name=annotation_val) +NOTIFY +(WEBHOOK alert_manager URL 'http://127.0.0.1:9093' WITH (timeout='1m'))"#; + + let stmt = + ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default()) + .unwrap() + .pop() + .unwrap(); + + let Statement::CreateTrigger(stmt) = stmt else { + unreachable!() + }; + + let query_ctx = QueryContext::arc(); + let expr = to_create_trigger_task_expr(stmt, &query_ctx).unwrap(); + + assert_eq!("greptime", expr.catalog_name); + assert_eq!("cpu_monitor", expr.trigger_name); + assert!(expr.create_if_not_exists); + assert_eq!( + "(SELECT host AS host_label, cpu, memory FROM machine_monitor WHERE cpu > 2)", + expr.sql + ); + assert_eq!(300, expr.interval); + assert_eq!(1, expr.labels.len()); + assert_eq!("label_val", expr.labels.get("label_name").unwrap()); + assert_eq!(1, expr.annotations.len()); + assert_eq!( + "annotation_val", + expr.annotations.get("annotation_name").unwrap() + ); + assert_eq!(1, expr.channels.len()); + let c = &expr.channels[0]; + assert_eq!("alert_manager", c.name,); + let channel_type = c.channel_type.as_ref().unwrap(); + let PbChannelType::Webhook(am) = &channel_type; + assert_eq!("http://127.0.0.1:9093", am.url); + assert_eq!(1, am.opts.len()); + assert_eq!( + "1m", + am.opts.get("timeout").expect("Expected timeout option") + ); + } +} diff --git a/src/operator/src/statement/ddl.rs b/src/operator/src/statement/ddl.rs index ccaa1bfa89..f02467f1a2 100644 --- a/src/operator/src/statement/ddl.rs +++ b/src/operator/src/statement/ddl.rs @@ -20,6 +20,10 @@ use api::v1::meta::CreateFlowTask as PbCreateFlowTask; use api::v1::{ column_def, AlterDatabaseExpr, AlterTableExpr, CreateFlowExpr, CreateTableExpr, CreateViewExpr, }; +#[cfg(feature = "enterprise")] +use api::v1::{ + meta::CreateTriggerTask as PbCreateTriggerTask, CreateTriggerExpr as PbCreateTriggerExpr, +}; use catalog::CatalogManagerRef; use chrono::Utc; use common_catalog::consts::{is_readonly_schema, DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; @@ -31,6 +35,8 @@ use common_meta::ddl::ExecutorContext; use common_meta::instruction::CacheIdent; use common_meta::key::schema_name::{SchemaName, SchemaNameKey}; use common_meta::key::NAME_PATTERN; +#[cfg(feature = "enterprise")] +use common_meta::rpc::ddl::trigger::CreateTriggerTask; use common_meta::rpc::ddl::{ CreateFlowTask, DdlTask, DropFlowTask, DropViewTask, SubmitDdlTaskRequest, SubmitDdlTaskResponse, @@ -58,6 +64,8 @@ use session::table_name::table_idents_to_full_name; use snafu::{ensure, OptionExt, ResultExt}; use sql::parser::{ParseOptions, ParserContext}; use sql::statements::alter::{AlterDatabase, AlterTable}; +#[cfg(feature = "enterprise")] +use sql::statements::create::trigger::CreateTrigger; use sql::statements::create::{ CreateExternalTable, CreateFlow, CreateTable, CreateTableLike, CreateView, Partitions, }; @@ -347,10 +355,43 @@ impl StatementExecutor { #[tracing::instrument(skip_all)] pub async fn create_trigger( &self, - _stmt: sql::statements::create::trigger::CreateTrigger, - _query_context: QueryContextRef, + stmt: CreateTrigger, + query_context: QueryContextRef, ) -> Result { - crate::error::UnsupportedTriggerSnafu {}.fail() + let expr = expr_helper::to_create_trigger_task_expr(stmt, &query_context)?; + self.create_trigger_inner(expr, query_context).await + } + + #[cfg(feature = "enterprise")] + pub async fn create_trigger_inner( + &self, + expr: PbCreateTriggerExpr, + query_context: QueryContextRef, + ) -> Result { + self.create_trigger_procedure(expr, query_context).await?; + Ok(Output::new_with_affected_rows(0)) + } + + #[cfg(feature = "enterprise")] + async fn create_trigger_procedure( + &self, + expr: PbCreateTriggerExpr, + query_context: QueryContextRef, + ) -> Result { + let task = CreateTriggerTask::try_from(PbCreateTriggerTask { + create_trigger: Some(expr), + }) + .context(error::InvalidExprSnafu)?; + + let request = SubmitDdlTaskRequest { + query_context, + task: DdlTask::new_create_trigger(task), + }; + + self.procedure_executor + .submit_ddl_task(&ExecutorContext::default(), request) + .await + .context(error::ExecuteDdlSnafu) } #[tracing::instrument(skip_all)] diff --git a/tests-integration/Cargo.toml b/tests-integration/Cargo.toml index f2f538f528..f576206cc0 100644 --- a/tests-integration/Cargo.toml +++ b/tests-integration/Cargo.toml @@ -6,6 +6,14 @@ license.workspace = true [features] dashboard = [] +enterprise = [ + "cmd/enterprise", + "common-meta/enterprise", + "frontend/enterprise", + "meta-srv/enterprise", + "operator/enterprise", + "sql/enterprise", +] [lints] workspace = true diff --git a/tests-integration/src/standalone.rs b/tests-integration/src/standalone.rs index 7f266f56bd..25159ca9e6 100644 --- a/tests-integration/src/standalone.rs +++ b/tests-integration/src/standalone.rs @@ -226,6 +226,8 @@ impl GreptimeDbStandaloneBuilder { }, procedure_manager.clone(), register_procedure_loaders, + #[cfg(feature = "enterprise")] + None, ) .unwrap(), );