diff --git a/Cargo.lock b/Cargo.lock index 1c5bb2729d..59c941f45d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2401,6 +2401,7 @@ dependencies = [ "object-store", "prometheus", "prost 0.13.5", + "prost-types 0.13.5", "rand 0.9.1", "regex", "rskafka", @@ -5318,9 +5319,10 @@ dependencies = [ [[package]] name = "greptime-proto" version = "0.1.0" -source = "git+https://github.com/GreptimeTeam/greptime-proto.git?rev=03007c30a2d2bf1acb4374cf5e92df9b0bd8844e#03007c30a2d2bf1acb4374cf5e92df9b0bd8844e" +source = "git+https://github.com/GreptimeTeam/greptime-proto.git?rev=df2bb74b5990c159dfd5b7a344eecf8f4307af64#df2bb74b5990c159dfd5b7a344eecf8f4307af64" dependencies = [ "prost 0.13.5", + "prost-types 0.13.5", "serde", "serde_json", "strum 0.25.0", @@ -8730,6 +8732,7 @@ dependencies = [ "path-slash", "prometheus", "prost 0.13.5", + "prost-types 0.13.5", "query", "regex", "serde_json", @@ -12101,6 +12104,7 @@ name = "sql" version = "0.17.0" dependencies = [ "api", + "arrow-buffer", "chrono", "common-base", "common-catalog", diff --git a/Cargo.toml b/Cargo.toml index 0edf6cec00..a525b8b0e0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -100,6 +100,7 @@ ahash = { version = "0.8", features = ["compile-time-rng"] } aquamarine = "0.6" arrow = { version = "56.0", features = ["prettyprint"] } arrow-array = { version = "56.0", default-features = false, features = ["chrono-tz"] } +arrow-buffer = "56.0" arrow-flight = "56.0" arrow-ipc = { version = "56.0", default-features = false, features = ["lz4", "zstd"] } arrow-schema = { version = "56.0", features = ["serde"] } @@ -141,7 +142,7 @@ etcd-client = { git = "https://github.com/GreptimeTeam/etcd-client", rev = "f62d fst = "0.4.7" futures = "0.3" futures-util = "0.3" -greptime-proto = { git = "https://github.com/GreptimeTeam/greptime-proto.git", rev = "03007c30a2d2bf1acb4374cf5e92df9b0bd8844e" } +greptime-proto = { git = "https://github.com/GreptimeTeam/greptime-proto.git", rev = "df2bb74b5990c159dfd5b7a344eecf8f4307af64" } hex = "0.4" http = "1" humantime = "2.1" @@ -178,6 +179,7 @@ pretty_assertions = "1.4.0" prometheus = { version = "0.13.3", features = ["process"] } promql-parser = { version = "0.6", features = ["ser"] } prost = { version = "0.13", features = ["no-recursion-limit"] } +prost-types = "0.13" raft-engine = { version = "0.4.1", default-features = false } rand = "0.9" ratelimit = "0.10" diff --git a/docker/dev-builder/centos/Dockerfile b/docker/dev-builder/centos/Dockerfile index bcbf5d9570..58c9baeaa1 100644 --- a/docker/dev-builder/centos/Dockerfile +++ b/docker/dev-builder/centos/Dockerfile @@ -19,7 +19,7 @@ ARG PROTOBUF_VERSION=29.3 RUN curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v${PROTOBUF_VERSION}/protoc-${PROTOBUF_VERSION}-linux-x86_64.zip && \ unzip protoc-${PROTOBUF_VERSION}-linux-x86_64.zip -d protoc3; - + RUN mv protoc3/bin/* /usr/local/bin/ RUN mv protoc3/include/* /usr/local/include/ diff --git a/src/common/meta/Cargo.toml b/src/common/meta/Cargo.toml index 216ca4f4d6..07d36c6c46 100644 --- a/src/common/meta/Cargo.toml +++ b/src/common/meta/Cargo.toml @@ -17,7 +17,7 @@ pg_kvbackend = [ "dep:rustls", ] mysql_kvbackend = ["dep:sqlx", "dep:backon"] -enterprise = [] +enterprise = ["prost-types"] [lints] workspace = true @@ -64,6 +64,7 @@ moka.workspace = true object-store.workspace = true prometheus.workspace = true prost.workspace = true +prost-types = { workspace = true, optional = true } rand.workspace = true regex.workspace = true rskafka.workspace = true diff --git a/src/common/meta/src/error.rs b/src/common/meta/src/error.rs index ff50bd0d7e..86cb64e466 100644 --- a/src/common/meta/src/error.rs +++ b/src/common/meta/src/error.rs @@ -1027,6 +1027,31 @@ pub enum Error { actual_column_name: String, actual_column_id: u32, }, + + #[cfg(feature = "enterprise")] + #[snafu(display("Too large duration"))] + TooLargeDuration { + #[snafu(source)] + error: prost_types::DurationError, + #[snafu(implicit)] + location: Location, + }, + + #[cfg(feature = "enterprise")] + #[snafu(display("Negative duration"))] + NegativeDuration { + #[snafu(source)] + error: prost_types::DurationError, + #[snafu(implicit)] + location: Location, + }, + + #[cfg(feature = "enterprise")] + #[snafu(display("Missing interval field"))] + MissingInterval { + #[snafu(implicit)] + location: Location, + }, } pub type Result = std::result::Result; @@ -1116,8 +1141,13 @@ impl ErrorExt for Error { | InvalidTimeZone { .. } | InvalidFileExtension { .. } | InvalidFileName { .. } + | InvalidFlowRequestBody { .. } | InvalidFilePath { .. } => StatusCode::InvalidArguments, - InvalidFlowRequestBody { .. } => StatusCode::InvalidArguments, + + #[cfg(feature = "enterprise")] + MissingInterval { .. } | NegativeDuration { .. } | TooLargeDuration { .. } => { + StatusCode::InvalidArguments + } FlowNotFound { .. } => StatusCode::FlowNotFound, FlowRouteNotFound { .. } => StatusCode::Unexpected, diff --git a/src/common/meta/src/rpc/ddl.rs b/src/common/meta/src/rpc/ddl.rs index 6c1c2168f0..a6ad4b5109 100644 --- a/src/common/meta/src/rpc/ddl.rs +++ b/src/common/meta/src/rpc/ddl.rs @@ -328,7 +328,7 @@ impl TryFrom for PbDdlTaskRequest { DdlTask::CreateView(task) => Task::CreateViewTask(task.try_into()?), DdlTask::DropView(task) => Task::DropViewTask(task.into()), #[cfg(feature = "enterprise")] - DdlTask::CreateTrigger(task) => Task::CreateTriggerTask(task.into()), + DdlTask::CreateTrigger(task) => Task::CreateTriggerTask(task.try_into()?), #[cfg(feature = "enterprise")] DdlTask::DropTrigger(task) => Task::DropTriggerTask(task.into()), }; diff --git a/src/common/meta/src/rpc/ddl/trigger.rs b/src/common/meta/src/rpc/ddl/trigger.rs index 28bc49d108..c231566cf3 100644 --- a/src/common/meta/src/rpc/ddl/trigger.rs +++ b/src/common/meta/src/rpc/ddl/trigger.rs @@ -10,7 +10,7 @@ use api::v1::{ NotifyChannel as PbNotifyChannel, WebhookOptions as PbWebhookOptions, }; use serde::{Deserialize, Serialize}; -use snafu::OptionExt; +use snafu::{OptionExt, ResultExt}; use crate::error; use crate::error::Result; @@ -27,6 +27,7 @@ pub struct CreateTriggerTask { pub labels: HashMap, pub annotations: HashMap, pub interval: Duration, + pub raw_interval_expr: String, } #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] @@ -51,14 +52,21 @@ pub struct WebhookOptions { pub opts: HashMap, } -impl From for PbCreateTriggerTask { - fn from(task: CreateTriggerTask) -> Self { +impl TryFrom for PbCreateTriggerTask { + type Error = error::Error; + + fn try_from(task: CreateTriggerTask) -> Result { let channels = task .channels .into_iter() .map(PbNotifyChannel::from) .collect(); + let interval = task + .interval + .try_into() + .context(error::TooLargeDurationSnafu)?; + let expr = PbCreateTriggerExpr { catalog_name: task.catalog_name, trigger_name: task.trigger_name, @@ -67,12 +75,13 @@ impl From for PbCreateTriggerTask { channels, labels: task.labels, annotations: task.annotations, - interval: task.interval.as_secs(), + interval: Some(interval), + raw_interval_expr: task.raw_interval_expr, }; - PbCreateTriggerTask { + Ok(PbCreateTriggerTask { create_trigger: Some(expr), - } + }) } } @@ -90,6 +99,9 @@ impl TryFrom for CreateTriggerTask { .map(NotifyChannel::try_from) .collect::>>()?; + let interval = expr.interval.context(error::MissingIntervalSnafu)?; + let interval = interval.try_into().context(error::NegativeDurationSnafu)?; + let task = CreateTriggerTask { catalog_name: expr.catalog_name, trigger_name: expr.trigger_name, @@ -98,7 +110,8 @@ impl TryFrom for CreateTriggerTask { channels, labels: expr.labels, annotations: expr.annotations, - interval: Duration::from_secs(expr.interval), + interval, + raw_interval_expr: expr.raw_interval_expr, }; Ok(task) } @@ -258,9 +271,10 @@ mod tests { .into_iter() .collect(), interval: Duration::from_secs(60), + raw_interval_expr: "'1 minute'::INTERVAL".to_string(), }; - let pb_task: PbCreateTriggerTask = original.clone().into(); + let pb_task: PbCreateTriggerTask = original.clone().try_into().unwrap(); let expr = pb_task.create_trigger.as_ref().unwrap(); assert_eq!(expr.catalog_name, "test_catalog"); @@ -277,7 +291,8 @@ mod tests { expr.annotations.get("description").unwrap(), "This is a test" ); - assert_eq!(expr.interval, 60); + let expected: prost_types::Duration = Duration::from_secs(60).try_into().unwrap(); + assert_eq!(expr.interval, Some(expected)); let round_tripped = CreateTriggerTask::try_from(pb_task).unwrap(); diff --git a/src/operator/Cargo.toml b/src/operator/Cargo.toml index f809f3a6f4..c4c43956d0 100644 --- a/src/operator/Cargo.toml +++ b/src/operator/Cargo.toml @@ -6,7 +6,7 @@ license.workspace = true [features] testing = [] -enterprise = ["common-meta/enterprise", "sql/enterprise", "query/enterprise"] +enterprise = ["common-meta/enterprise", "sql/enterprise", "query/enterprise", "prost-types"] [lints] workspace = true @@ -56,6 +56,7 @@ object_store_opendal.workspace = true partition.workspace = true prometheus.workspace = true prost.workspace = true +prost-types = { workspace = true, optional = true } query.workspace = true regex.workspace = true serde_json.workspace = true diff --git a/src/operator/src/error.rs b/src/operator/src/error.rs index 0c96a419de..9ab849292a 100644 --- a/src/operator/src/error.rs +++ b/src/operator/src/error.rs @@ -851,6 +851,15 @@ pub enum Error { #[snafu(implicit)] location: Location, }, + + #[cfg(feature = "enterprise")] + #[snafu(display("Too large duration"))] + TooLargeDuration { + #[snafu(source)] + error: prost_types::DurationError, + #[snafu(implicit)] + location: Location, + }, } pub type Result = std::result::Result; @@ -897,6 +906,8 @@ impl ErrorExt for Error { | Error::CreatePartitionRules { .. } => StatusCode::InvalidArguments, #[cfg(feature = "enterprise")] Error::InvalidTriggerName { .. } => StatusCode::InvalidArguments, + #[cfg(feature = "enterprise")] + Error::TooLargeDuration { .. } => StatusCode::InvalidArguments, Error::TableAlreadyExists { .. } | Error::ViewAlreadyExists { .. } => { StatusCode::TableAlreadyExists } diff --git a/src/operator/src/expr_helper/trigger.rs b/src/operator/src/expr_helper/trigger.rs index 041f9018ee..c6e7e5cead 100644 --- a/src/operator/src/expr_helper/trigger.rs +++ b/src/operator/src/expr_helper/trigger.rs @@ -4,10 +4,11 @@ use api::v1::{ WebhookOptions as PbWebhookOptions, }; use session::context::QueryContextRef; -use snafu::ensure; +use snafu::{ensure, ResultExt}; use sql::ast::{ObjectName, ObjectNamePartExt}; -use sql::statements::create::trigger::{ChannelType, CreateTrigger}; +use sql::statements::create::trigger::{ChannelType, CreateTrigger, TriggerOn}; +use crate::error; use crate::error::Result; pub fn to_create_trigger_task_expr( @@ -17,13 +18,18 @@ pub fn to_create_trigger_task_expr( let CreateTrigger { trigger_name, if_not_exists, - query, - interval, + trigger_on, labels, annotations, channels, } = create_trigger; + let TriggerOn { + query, + interval, + raw_interval_expr, + } = trigger_on; + let catalog_name = query_ctx.current_catalog().to_string(); let trigger_name = sanitize_trigger_name(trigger_name)?; @@ -47,6 +53,8 @@ pub fn to_create_trigger_task_expr( let labels = labels.into_map(); let annotations = annotations.into_map(); + let interval = interval.try_into().context(error::TooLargeDurationSnafu)?; + Ok(PbCreateTriggerExpr { catalog_name, trigger_name, @@ -55,7 +63,8 @@ pub fn to_create_trigger_task_expr( channels, labels, annotations, - interval, + interval: Some(interval), + raw_interval_expr, }) } @@ -72,6 +81,8 @@ fn sanitize_trigger_name(mut trigger_name: ObjectName) -> Result { #[cfg(test)] mod tests { + use std::time::Duration; + use session::context::QueryContext; use sql::dialect::GreptimeDbDialect; use sql::parser::{ParseOptions, ParserContext}; @@ -123,7 +134,8 @@ NOTIFY "(SELECT host AS host_label, cpu, memory FROM machine_monitor WHERE cpu > 2)", expr.sql ); - assert_eq!(300, expr.interval); + let expected: prost_types::Duration = Duration::from_secs(300).try_into().unwrap(); + assert_eq!(Some(expected), expr.interval); assert_eq!(1, expr.labels.len()); assert_eq!("label_val", expr.labels.get("label_name").unwrap()); assert_eq!(1, expr.annotations.len()); diff --git a/src/sql/Cargo.toml b/src/sql/Cargo.toml index 9fbae28893..8459bb375d 100644 --- a/src/sql/Cargo.toml +++ b/src/sql/Cargo.toml @@ -12,6 +12,7 @@ enterprise = [] [dependencies] api.workspace = true +arrow-buffer.workspace = true chrono.workspace = true common-base.workspace = true common-catalog.workspace = true diff --git a/src/sql/src/error.rs b/src/sql/src/error.rs index c840311591..1748d3b9de 100644 --- a/src/sql/src/error.rs +++ b/src/sql/src/error.rs @@ -314,15 +314,6 @@ pub enum Error { location: Location, }, - #[cfg(feature = "enterprise")] - #[snafu(display("The execution interval cannot be negative"))] - NegativeInterval { - #[snafu(source)] - error: std::num::TryFromIntError, - #[snafu(implicit)] - location: Location, - }, - #[cfg(feature = "enterprise")] #[snafu(display("Must specify at least one notify channel"))] MissingNotifyChannel { @@ -387,9 +378,7 @@ impl ErrorExt for Error { InvalidTriggerName { .. } => StatusCode::InvalidArguments, #[cfg(feature = "enterprise")] - InvalidTriggerWebhookOption { .. } | NegativeInterval { .. } => { - StatusCode::InvalidArguments - } + InvalidTriggerWebhookOption { .. } => StatusCode::InvalidArguments, SerializeColumnDefaultConstraint { source, .. } => source.status_code(), ConvertToGrpcDataType { source, .. } => source.status_code(), diff --git a/src/sql/src/parsers/alter_parser/trigger.rs b/src/sql/src/parsers/alter_parser/trigger.rs index d68b4864ad..b7a47bb8b0 100644 --- a/src/sql/src/parsers/alter_parser/trigger.rs +++ b/src/sql/src/parsers/alter_parser/trigger.rs @@ -55,8 +55,7 @@ impl<'a> ParserContext<'a> { let trigger_name = self.intern_parse_table_name()?; let mut new_trigger_name = None; - let mut new_query = None; - let mut new_interval = None; + let mut trigger_on = None; let mut label_ops = None; let mut annotation_ops = None; let mut notify_ops = None; @@ -78,13 +77,9 @@ impl<'a> ParserContext<'a> { } Token::Word(w) if w.value.eq_ignore_ascii_case(ON) => { self.parser.next_token(); - let (query, interval) = self.parse_trigger_on(true)?; - ensure!( - new_query.is_none() && new_interval.is_none(), - DuplicateClauseSnafu { clause: ON } - ); - new_query.replace(query); - new_interval.replace(interval); + let new_trigger_on = self.parse_trigger_on(true)?; + ensure!(trigger_on.is_none(), DuplicateClauseSnafu { clause: ON }); + trigger_on.replace(new_trigger_on); } Token::Word(w) if w.value.eq_ignore_ascii_case(LABELS) => { self.parser.next_token(); @@ -230,8 +225,7 @@ impl<'a> ParserContext<'a> { } if new_trigger_name.is_none() - && new_query.is_none() - && new_interval.is_none() + && trigger_on.is_none() && label_ops.is_none() && annotation_ops.is_none() && notify_ops.is_none() @@ -241,8 +235,7 @@ impl<'a> ParserContext<'a> { let operation = AlterTriggerOperation { rename: new_trigger_name, - new_query, - new_interval, + trigger_on, label_operations: label_ops, annotation_operations: annotation_ops, notify_channel_operations: notify_ops, @@ -544,10 +537,13 @@ fn apply_notify_change( #[cfg(test)] mod tests { + use std::time::Duration; + use crate::dialect::GreptimeDbDialect; use crate::parser::ParserContext; use crate::parsers::alter_parser::trigger::{apply_label_change, apply_label_replacement}; use crate::statements::alter::trigger::{LabelChange, LabelOperations}; + use crate::statements::create::trigger::TriggerOn; use crate::statements::statement::Statement; use crate::statements::OptionMap; @@ -571,9 +567,14 @@ mod tests { let Statement::AlterTrigger(alter) = stmt else { panic!("Expected AlterTrigger statement"); }; - assert!(alter.operation.new_query.is_some()); - assert!(alter.operation.new_interval.is_some()); - assert_eq!(alter.operation.new_interval.unwrap(), 300); + let TriggerOn { + query, + interval, + raw_interval_expr, + } = alter.operation.trigger_on.unwrap(); + assert_eq!(query.to_string(), "(SELECT * FROM test_table)"); + assert_eq!(raw_interval_expr, "'5 minute'::INTERVAL"); + assert_eq!(interval, Duration::from_secs(300)); assert!(alter.operation.rename.is_none()); assert!(alter.operation.label_operations.is_none()); assert!(alter.operation.annotation_operations.is_none()); diff --git a/src/sql/src/parsers/create_parser.rs b/src/sql/src/parsers/create_parser.rs index 2b7d789e48..280e357f6d 100644 --- a/src/sql/src/parsers/create_parser.rs +++ b/src/sql/src/parsers/create_parser.rs @@ -16,7 +16,9 @@ pub mod trigger; use std::collections::HashMap; +use std::time::Duration; +use arrow_buffer::IntervalMonthDayNano; use common_catalog::consts::default_engine; use datafusion_common::ScalarValue; use datatypes::arrow::datatypes::{DataType as ArrowDataType, IntervalUnit}; @@ -58,6 +60,8 @@ pub const AFTER: &str = "AFTER"; pub const INVERTED: &str = "INVERTED"; pub const SKIPPING: &str = "SKIPPING"; +pub type RawIntervalExpr = String; + /// Parses create [table] statement impl<'a> ParserContext<'a> { pub(crate) fn parse_create(&mut self) -> Result { @@ -348,7 +352,59 @@ impl<'a> ParserContext<'a> { /// Parse the interval expr to duration in seconds. fn parse_interval(&mut self) -> Result { + let interval = self.parse_interval_month_day_nano()?.0; + Ok( + interval.nanoseconds / 1_000_000_000 + + interval.days as i64 * 60 * 60 * 24 + + interval.months as i64 * 60 * 60 * 24 * 3044 / 1000, // 1 month=365.25/12=30.44 days + // this is to keep the same as https://docs.rs/humantime/latest/humantime/fn.parse_duration.html + // which we use in database to parse i.e. ttl interval and many other intervals + ) + } + + /// Parses an interval expression and converts it to a standard Rust [`Duration`] + /// and a raw interval expression string. + pub fn parse_interval_to_duration(&mut self) -> Result<(Duration, RawIntervalExpr)> { + let (interval, raw_interval_expr) = self.parse_interval_month_day_nano()?; + + let months: i64 = interval.months.into(); + let days: i64 = interval.days.into(); + let months_in_seconds: i64 = months * 60 * 60 * 24 * 3044 / 1000; + let days_in_seconds: i64 = days * 60 * 60 * 24; + let seconds_from_nanos = interval.nanoseconds / 1_000_000_000; + let total_seconds = months_in_seconds + days_in_seconds + seconds_from_nanos; + + let mut nanos_remainder = interval.nanoseconds % 1_000_000_000; + let mut adjusted_seconds = total_seconds; + + if nanos_remainder < 0 { + nanos_remainder += 1_000_000_000; + adjusted_seconds -= 1; + } + + ensure!( + adjusted_seconds >= 0, + InvalidIntervalSnafu { + reason: "must be a positive interval", + } + ); + + // Cast safety: `adjusted_seconds` is guaranteed to be non-negative before. + let adjusted_seconds = adjusted_seconds as u64; + // Cast safety: `nanos_remainder` is smaller than 1_000_000_000 which + // is checked above. + let nanos_remainder = nanos_remainder as u32; + + Ok(( + Duration::new(adjusted_seconds, nanos_remainder), + raw_interval_expr, + )) + } + + /// Parse interval expr to [`IntervalMonthDayNano`]. + fn parse_interval_month_day_nano(&mut self) -> Result<(IntervalMonthDayNano, RawIntervalExpr)> { let interval_expr = self.parser.parse_expr().context(error::SyntaxSnafu)?; + let raw_interval_expr = interval_expr.to_string(); let interval = utils::parser_expr_to_scalar_value_literal(interval_expr.clone())? .cast_to(&ArrowDataType::Interval(IntervalUnit::MonthDayNano)) .ok() @@ -356,13 +412,7 @@ impl<'a> ParserContext<'a> { reason: format!("cannot cast {} to interval type", interval_expr), })?; if let ScalarValue::IntervalMonthDayNano(Some(interval)) = interval { - Ok( - interval.nanoseconds / 1_000_000_000 - + interval.days as i64 * 60 * 60 * 24 - + interval.months as i64 * 60 * 60 * 24 * 3044 / 1000, // 1 month=365.25/12=30.44 days - // this is to keep the same as https://docs.rs/humantime/latest/humantime/fn.parse_duration.html - // which we use in database to parse i.e. ttl interval and many other intervals - ) + Ok((interval, raw_interval_expr)) } else { unreachable!() } diff --git a/src/sql/src/parsers/create_parser/trigger.rs b/src/sql/src/parsers/create_parser/trigger.rs index 26fced155a..7c6c89f4c4 100644 --- a/src/sql/src/parsers/create_parser/trigger.rs +++ b/src/sql/src/parsers/create_parser/trigger.rs @@ -1,7 +1,6 @@ use std::collections::HashMap; use snafu::{ensure, OptionExt, ResultExt}; -use sqlparser::ast::Query; use sqlparser::keywords::Keyword; use sqlparser::parser::Parser; use sqlparser::tokenizer::Token; @@ -10,7 +9,7 @@ use crate::error; use crate::error::Result; use crate::parser::ParserContext; use crate::statements::create::trigger::{ - AlertManagerWebhook, ChannelType, CreateTrigger, NotifyChannel, + AlertManagerWebhook, ChannelType, CreateTrigger, NotifyChannel, TriggerOn, }; use crate::statements::statement::Statement; use crate::statements::OptionMap; @@ -52,8 +51,7 @@ impl<'a> ParserContext<'a> { let if_not_exists = self.parse_if_not_exist()?; let trigger_name = self.intern_parse_table_name()?; - let mut may_query = None; - let mut may_interval = None; + let mut may_trigger_on = None; let mut may_labels = None; let mut may_annotations = None; let mut notify_channels = vec![]; @@ -63,9 +61,8 @@ impl<'a> ParserContext<'a> { match next_token.token { Token::Word(w) if w.value.eq_ignore_ascii_case(ON) => { self.parser.next_token(); - let (query, interval) = self.parse_trigger_on(true)?; - may_query.replace(query); - may_interval.replace(interval); + let trigger_on = self.parse_trigger_on(true)?; + may_trigger_on.replace(trigger_on); } Token::Word(w) if w.value.eq_ignore_ascii_case(LABELS) => { self.parser.next_token(); @@ -92,8 +89,7 @@ impl<'a> ParserContext<'a> { } } - let query = may_query.context(error::MissingClauseSnafu { name: ON })?; - let interval = may_interval.context(error::MissingClauseSnafu { name: ON })?; + let trigger_on = may_trigger_on.context(error::MissingClauseSnafu { name: ON })?; let labels = may_labels.unwrap_or_default(); let annotations = may_annotations.unwrap_or_default(); @@ -105,8 +101,7 @@ impl<'a> ParserContext<'a> { let create_trigger = CreateTrigger { trigger_name, if_not_exists, - query, - interval, + trigger_on, labels, annotations, channels: notify_channels, @@ -125,10 +120,7 @@ impl<'a> ParserContext<'a> { /// /// - `is_first_keyword_matched`: indicates whether the first keyword `ON` /// has been matched. - pub(crate) fn parse_trigger_on( - &mut self, - is_first_keyword_matched: bool, - ) -> Result<(Box, u64)> { + pub(crate) fn parse_trigger_on(&mut self, is_first_keyword_matched: bool) -> Result { if !is_first_keyword_matched { if let Token::Word(w) = self.parser.peek_token().token && w.value.eq_ignore_ascii_case(ON) @@ -149,12 +141,13 @@ impl<'a> ParserContext<'a> { return self.expected("`EVERY` keyword", self.parser.peek_token()); } - let interval = self - .parse_interval()? - .try_into() - .context(error::NegativeIntervalSnafu)?; + let (interval, raw_interval_expr) = self.parse_interval_to_duration()?; - Ok((query, interval)) + Ok(TriggerOn { + query, + interval, + raw_interval_expr, + }) } /// The SQL format as follows: @@ -380,6 +373,8 @@ impl<'a> ParserContext<'a> { #[cfg(test)] mod tests { + use std::time::Duration; + use super::*; use crate::dialect::GreptimeDbDialect; use crate::statements::create::trigger::ChannelType; @@ -452,10 +447,20 @@ IF NOT EXISTS cpu_monitor assert!(create_trigger.if_not_exists); assert_eq!(create_trigger.trigger_name.to_string(), "cpu_monitor"); assert_eq!( - create_trigger.query.to_string(), + create_trigger.trigger_on.query.to_string(), "(SELECT host AS host_label, cpu, memory FROM machine_monitor WHERE cpu > 1)" ); - assert_eq!(create_trigger.interval, 300); + let TriggerOn { + query, + interval, + raw_interval_expr, + } = &create_trigger.trigger_on; + assert_eq!( + query.to_string(), + "(SELECT host AS host_label, cpu, memory FROM machine_monitor WHERE cpu > 1)" + ); + assert_eq!(*interval, Duration::from_secs(300)); + assert_eq!(raw_interval_expr.to_string(), "'5 minute'::INTERVAL"); assert_eq!(create_trigger.labels.len(), 1); assert_eq!( create_trigger.labels.get("label_name").unwrap(), @@ -487,9 +492,14 @@ IF NOT EXISTS cpu_monitor // Normal. let sql = "ON (SELECT * FROM cpu_usage) EVERY '5 minute'::INTERVAL"; let mut ctx = ParserContext::new(&GreptimeDbDialect {}, sql).unwrap(); - let (query, interval) = ctx.parse_trigger_on(false).unwrap(); + let TriggerOn { + query, + interval, + raw_interval_expr: raw_interval, + } = ctx.parse_trigger_on(false).unwrap(); assert_eq!(query.to_string(), "(SELECT * FROM cpu_usage)"); - assert_eq!(interval, 300); + assert_eq!(interval, Duration::from_secs(300)); + assert_eq!(raw_interval, "'5 minute'::INTERVAL"); // Invalid, since missing `ON` keyword. let sql = "SELECT * FROM cpu_usage EVERY '5 minute'::INTERVAL"; diff --git a/src/sql/src/statements/alter/trigger.rs b/src/sql/src/statements/alter/trigger.rs index e7488cbeba..93c02e7158 100644 --- a/src/sql/src/statements/alter/trigger.rs +++ b/src/sql/src/statements/alter/trigger.rs @@ -1,10 +1,10 @@ use std::fmt::{Display, Formatter}; use serde::Serialize; -use sqlparser::ast::{ObjectName, Query}; +use sqlparser::ast::ObjectName; use sqlparser_derive::{Visit, VisitMut}; -use crate::statements::create::trigger::NotifyChannel; +use crate::statements::create::trigger::{NotifyChannel, TriggerOn}; use crate::statements::OptionMap; #[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut, Serialize)] @@ -16,9 +16,7 @@ pub struct AlterTrigger { #[derive(Debug, Default, Clone, PartialEq, Eq, Visit, VisitMut, Serialize)] pub struct AlterTriggerOperation { pub rename: Option, - pub new_query: Option>, - /// The new interval of exec query. Unit is second. - pub new_interval: Option, + pub trigger_on: Option, pub label_operations: Option, pub annotation_operations: Option, pub notify_channel_operations: Option, @@ -35,12 +33,9 @@ impl Display for AlterTrigger { write!(f, "RENAME TO {}", new_name)?; } - if let Some((new_query, new_interval)) = - operation.new_query.as_ref().zip(operation.new_interval) - { + if let Some(trigger_on) = &operation.trigger_on { writeln!(f)?; - write!(f, "ON {}", new_query)?; - write!(f, " EVERY {} SECONDS", new_interval)?; + write!(f, "{}", trigger_on)?; } if let Some(label_ops) = &operation.label_operations { @@ -319,7 +314,7 @@ ADD NOTIFY let formatted = format!("{}", trigger); let expected = r#"ALTER TRIGGER my_trigger RENAME TO new_trigger -ON (SELECT host AS host_label, cpu, memory FROM machine_monitor WHERE cpu > 2) EVERY 300 SECONDS +ON (SELECT host AS host_label, cpu, memory FROM machine_monitor WHERE cpu > 2) EVERY '5 minute'::INTERVAL ADD LABELS (k1 = 'v1', k2 = 'v2') DROP LABELS (k3, k4) SET ANNOTATIONS (a1 = 'v1', a2 = 'v2') diff --git a/src/sql/src/statements/create/trigger.rs b/src/sql/src/statements/create/trigger.rs index 999b2a30af..a6cc3d88d8 100644 --- a/src/sql/src/statements/create/trigger.rs +++ b/src/sql/src/statements/create/trigger.rs @@ -1,8 +1,10 @@ use std::fmt::{Display, Formatter}; +use std::ops::ControlFlow; +use std::time::Duration; use itertools::Itertools; use serde::Serialize; -use sqlparser::ast::Query; +use sqlparser::ast::{Query, Visit, VisitMut, Visitor, VisitorMut}; use sqlparser_derive::{Visit, VisitMut}; use crate::ast::{Ident, ObjectName}; @@ -13,10 +15,7 @@ use crate::statements::OptionMap; pub struct CreateTrigger { pub trigger_name: ObjectName, pub if_not_exists: bool, - /// SQL statement executed periodically. - pub query: Box, - /// The interval of exec query. Unit is second. - pub interval: u64, + pub trigger_on: TriggerOn, pub labels: OptionMap, pub annotations: OptionMap, pub channels: Vec, @@ -29,8 +28,7 @@ impl Display for CreateTrigger { write!(f, "IF NOT EXISTS ")?; } writeln!(f, "{}", self.trigger_name)?; - write!(f, "ON {} ", self.query)?; - writeln!(f, "EVERY {} SECONDS", self.interval)?; + writeln!(f, "{}", self.trigger_on)?; if !self.labels.is_empty() { let labels = self.labels.kv_pairs(); @@ -73,6 +71,33 @@ impl Display for NotifyChannel { } } +#[derive(Debug, PartialEq, Eq, Clone, Serialize)] +pub struct TriggerOn { + pub query: Box, + pub interval: Duration, + pub raw_interval_expr: String, +} + +impl Display for TriggerOn { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "ON {} EVERY {}", self.query, self.raw_interval_expr) + } +} + +impl Visit for TriggerOn { + fn visit(&self, visitor: &mut V) -> ControlFlow { + Visit::visit(&self.query, visitor)?; + ControlFlow::Continue(()) + } +} + +impl VisitMut for TriggerOn { + fn visit(&mut self, visitor: &mut V) -> ControlFlow { + VisitMut::visit(&mut self.query, visitor)?; + ControlFlow::Continue(()) + } +} + #[derive(Debug, PartialEq, Eq, Clone, Visit, VisitMut, Serialize)] pub enum ChannelType { /// Alert manager webhook options. @@ -94,7 +119,7 @@ mod tests { #[test] fn test_display_create_trigger() { let sql = r#"CREATE TRIGGER IF NOT EXISTS cpu_monitor -ON (SELECT host AS host_label, cpu, memory FROM machine_monitor WHERE cpu > 2) EVERY '5 minute'::INTERVAL +ON (SELECT host AS host_label, cpu, memory FROM machine_monitor WHERE cpu > 2) EVERY '1day 5 minute'::INTERVAL LABELS (label_name=label_val) ANNOTATIONS (annotation_name=annotation_val) NOTIFY @@ -110,7 +135,7 @@ NOTIFY }; let formatted = format!("{}", trigger); let expected = r#"CREATE TRIGGER IF NOT EXISTS cpu_monitor -ON (SELECT host AS host_label, cpu, memory FROM machine_monitor WHERE cpu > 2) EVERY 300 SECONDS +ON (SELECT host AS host_label, cpu, memory FROM machine_monitor WHERE cpu > 2) EVERY '1day 5 minute'::INTERVAL LABELS (label_name = 'label_val') ANNOTATIONS (annotation_name = 'annotation_val') NOTIFY(