Skip to main content

sql/
util.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashSet;
16use std::fmt::{Display, Formatter};
17
18use itertools::Itertools;
19use promql_parser::label::{METRIC_NAME, MatchOp};
20use promql_parser::parser::{
21    AggregateExpr as PromAggregateExpr, BinaryExpr as PromBinaryExpr, Call as PromCall,
22    Expr as PromExpr, MatrixSelector as PromMatrixSelector, ParenExpr as PromParenExpr,
23    SubqueryExpr as PromSubqueryExpr, UnaryExpr as PromUnaryExpr,
24    VectorSelector as PromVectorSelector,
25};
26use serde::Serialize;
27use snafu::ensure;
28use sqlparser::ast::{
29    Array, Expr, Ident, ObjectName, ObjectNamePart, SetExpr, SqlOption, TableFactor,
30    TableWithJoins, Value, ValueWithSpan,
31};
32use sqlparser_derive::{Visit, VisitMut};
33
34use crate::ast::ObjectNamePartExt;
35use crate::error::{InvalidExprAsOptionValueSnafu, InvalidSqlSnafu, Result};
36use crate::parser::ParserContext;
37use crate::parsers::with_tql_parser::CteContent;
38use crate::statements::create::SqlOrTql;
39use crate::statements::query::Query;
40use crate::statements::tql::Tql;
41
/// PromQL label whose `=` matcher selects the schema of the metric table,
/// e.g. `http_requests{__schema__="greptime_private"}`.
const SCHEMA_MATCHER: &str = "__schema__";
/// Alternative spelling of [SCHEMA_MATCHER]; it is treated exactly the same way.
const DATABASE_MATCHER: &str = "__database__";
44
45/// Format an [ObjectName] without any quote of its idents.
46pub fn format_raw_object_name(name: &ObjectName) -> String {
47    struct Inner<'a> {
48        name: &'a ObjectName,
49    }
50
51    impl Display for Inner<'_> {
52        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
53            let mut delim = "";
54            for ident in self.name.0.iter() {
55                write!(f, "{delim}")?;
56                delim = ".";
57                write!(f, "{}", ident.to_string_unquoted())?;
58            }
59            Ok(())
60        }
61    }
62
63    format!("{}", Inner { name })
64}
65
66#[derive(Clone, Debug, PartialEq, Eq, Serialize, Visit, VisitMut)]
67pub struct OptionValue(Expr);
68
69impl OptionValue {
70    pub(crate) fn try_new(expr: Expr) -> Result<Self> {
71        ensure!(
72            matches!(
73                expr,
74                Expr::Value(_) | Expr::Identifier(_) | Expr::Array(_) | Expr::Struct { .. }
75            ),
76            InvalidExprAsOptionValueSnafu {
77                error: format!("{expr} not accepted")
78            }
79        );
80        Ok(Self(expr))
81    }
82
83    fn expr_as_string(expr: &Expr) -> Option<&str> {
84        match expr {
85            Expr::Value(ValueWithSpan { value, .. }) => match value {
86                Value::SingleQuotedString(s)
87                | Value::DoubleQuotedString(s)
88                | Value::TripleSingleQuotedString(s)
89                | Value::TripleDoubleQuotedString(s)
90                | Value::SingleQuotedByteStringLiteral(s)
91                | Value::DoubleQuotedByteStringLiteral(s)
92                | Value::TripleSingleQuotedByteStringLiteral(s)
93                | Value::TripleDoubleQuotedByteStringLiteral(s)
94                | Value::SingleQuotedRawStringLiteral(s)
95                | Value::DoubleQuotedRawStringLiteral(s)
96                | Value::TripleSingleQuotedRawStringLiteral(s)
97                | Value::TripleDoubleQuotedRawStringLiteral(s)
98                | Value::EscapedStringLiteral(s)
99                | Value::UnicodeStringLiteral(s)
100                | Value::NationalStringLiteral(s)
101                | Value::HexStringLiteral(s) => Some(s),
102                Value::DollarQuotedString(s) => Some(&s.value),
103                Value::Number(s, _) => Some(s),
104                Value::Boolean(b) => Some(if *b { "true" } else { "false" }),
105                _ => None,
106            },
107            Expr::Identifier(ident) => Some(&ident.value),
108            _ => None,
109        }
110    }
111
112    /// Convert the option value to a string.
113    ///
114    /// Notes: Not all values can be converted to a string, refer to [Self::expr_as_string] for more details.
115    pub fn as_string(&self) -> Option<&str> {
116        Self::expr_as_string(&self.0)
117    }
118
119    pub fn as_list(&self) -> Option<Vec<&str>> {
120        let expr = &self.0;
121        match expr {
122            Expr::Value(_) | Expr::Identifier(_) => self.as_string().map(|s| vec![s]),
123            Expr::Array(array) => array
124                .elem
125                .iter()
126                .map(Self::expr_as_string)
127                .collect::<Option<Vec<_>>>(),
128            _ => None,
129        }
130    }
131}
132
133impl From<String> for OptionValue {
134    fn from(value: String) -> Self {
135        Self(Expr::Identifier(Ident::new(value)))
136    }
137}
138
139impl From<&str> for OptionValue {
140    fn from(value: &str) -> Self {
141        Self(Expr::Identifier(Ident::new(value)))
142    }
143}
144
145impl From<Vec<&str>> for OptionValue {
146    fn from(value: Vec<&str>) -> Self {
147        Self(Expr::Array(Array {
148            elem: value
149                .into_iter()
150                .map(|x| Expr::Identifier(Ident::new(x)))
151                .collect(),
152            named: false,
153        }))
154    }
155}
156
157impl Display for OptionValue {
158    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
159        if let Some(s) = self.as_string() {
160            write!(f, "'{s}'")
161        } else if let Some(s) = self.as_list() {
162            write!(
163                f,
164                "[{}]",
165                s.into_iter().map(|x| format!("'{x}'")).join(", ")
166            )
167        } else {
168            write!(f, "'{}'", self.0)
169        }
170    }
171}
172
173pub fn parse_option_string(option: SqlOption) -> Result<(String, OptionValue)> {
174    let SqlOption::KeyValue { key, value } = option else {
175        return InvalidSqlSnafu {
176            msg: "Expecting a key-value pair in the option",
177        }
178        .fail();
179    };
180    let v = OptionValue::try_new(value)?;
181    let k = key.value.to_lowercase();
182    Ok((k, v))
183}
184
185/// Walk through a [Query] and extract all the tables referenced in it.
186pub fn extract_tables_from_query(query: &SqlOrTql) -> impl Iterator<Item = ObjectName> {
187    let mut names = HashSet::new();
188
189    match query {
190        SqlOrTql::Sql(query, _) => {
191            extract_tables_from_sql_query(&query.inner, &mut names);
192            extract_tables_from_hybrid_cte_query(query, &mut names);
193        }
194        SqlOrTql::Tql(tql, _) => extract_tables_from_tql(tql, &mut names),
195    }
196
197    names.into_iter()
198}
199
200fn extract_tables_from_hybrid_cte_query(query: &Query, sql_names: &mut HashSet<ObjectName>) {
201    if let Some(hybrid_cte) = &query.hybrid_cte {
202        let mut cte_names: HashSet<String> = hybrid_cte
203            .cte_tables
204            .iter()
205            .map(|cte| ParserContext::canonicalize_identifier(cte.name.clone()).value)
206            .collect();
207        remove_cte_names(sql_names, &cte_names);
208
209        cte_names.clear();
210        for cte in &hybrid_cte.cte_tables {
211            let cte_name = ParserContext::canonicalize_identifier(cte.name.clone()).value;
212            let mut cte_query_names = HashSet::new();
213            match &cte.content {
214                CteContent::Sql(cte_query) => {
215                    extract_tables_from_sql_query(cte_query, &mut cte_query_names)
216                }
217                CteContent::Tql(tql) => extract_tables_from_tql(tql, &mut cte_query_names),
218            }
219            if hybrid_cte.recursive {
220                cte_names.insert(cte_name.clone());
221            }
222            remove_cte_names(&mut cte_query_names, &cte_names);
223            sql_names.extend(cte_query_names);
224            if !hybrid_cte.recursive {
225                cte_names.insert(cte_name);
226            }
227        }
228    }
229}
230
231fn remove_cte_names(names: &mut HashSet<ObjectName>, cte_names: &HashSet<String>) {
232    if cte_names.is_empty() {
233        return;
234    }
235
236    names.retain(|name| {
237        if name.0.len() != 1 {
238            return true;
239        }
240        let Some(ident) = name.0[0].as_ident() else {
241            return true;
242        };
243
244        let canonical = ParserContext::canonicalize_identifier(ident.clone()).value;
245        !cte_names.contains(&canonical)
246    });
247}
248
249fn extract_tables_from_tql(tql: &Tql, names: &mut HashSet<ObjectName>) {
250    let promql = match tql {
251        Tql::Eval(eval) => &eval.query,
252        Tql::Explain(explain) => &explain.query,
253        Tql::Analyze(analyze) => &analyze.query,
254    };
255
256    if let Ok(expr) = promql_parser::parser::parse(promql) {
257        extract_tables_from_prom_expr(&expr, names);
258    }
259}
260
261fn extract_tables_from_prom_expr(expr: &PromExpr, names: &mut HashSet<ObjectName>) {
262    match expr {
263        PromExpr::Aggregate(PromAggregateExpr { expr, .. }) => {
264            extract_tables_from_prom_expr(expr, names);
265        }
266        PromExpr::Unary(PromUnaryExpr { expr, .. }) => {
267            extract_tables_from_prom_expr(expr, names);
268        }
269        PromExpr::Binary(PromBinaryExpr { lhs, rhs, .. }) => {
270            extract_tables_from_prom_expr(lhs, names);
271            extract_tables_from_prom_expr(rhs, names);
272        }
273        PromExpr::Paren(PromParenExpr { expr }) => {
274            extract_tables_from_prom_expr(expr, names);
275        }
276        PromExpr::Subquery(PromSubqueryExpr { expr, .. }) => {
277            extract_tables_from_prom_expr(expr, names);
278        }
279        PromExpr::VectorSelector(selector) => {
280            extract_metric_name_from_vector_selector(selector, names);
281        }
282        PromExpr::MatrixSelector(PromMatrixSelector { vs, .. }) => {
283            extract_metric_name_from_vector_selector(vs, names);
284        }
285        PromExpr::Call(PromCall { args, .. }) => {
286            for arg in &args.args {
287                extract_tables_from_prom_expr(arg, names);
288            }
289        }
290        PromExpr::NumberLiteral(_) | PromExpr::StringLiteral(_) | PromExpr::Extension(_) => {}
291    }
292}
293
294fn extract_metric_name_from_vector_selector(
295    selector: &PromVectorSelector,
296    names: &mut HashSet<ObjectName>,
297) {
298    let metric_name = selector.name.clone().or_else(|| {
299        let mut metric_name_matchers = selector.matchers.find_matchers(METRIC_NAME);
300        if metric_name_matchers.len() == 1 && metric_name_matchers[0].op == MatchOp::Equal {
301            metric_name_matchers.pop().map(|matcher| matcher.value)
302        } else {
303            None
304        }
305    });
306    let Some(metric_name) = metric_name else {
307        return;
308    };
309
310    let schema_matcher = selector.matchers.matchers.iter().rev().find(|matcher| {
311        matcher.op == MatchOp::Equal
312            && (matcher.name == SCHEMA_MATCHER || matcher.name == DATABASE_MATCHER)
313    });
314
315    if let Some(schema) = schema_matcher {
316        names.insert(ObjectName(vec![
317            ObjectNamePart::Identifier(Ident::new(&schema.value)),
318            ObjectNamePart::Identifier(Ident::new(metric_name)),
319        ]));
320    } else {
321        names.insert(ObjectName(vec![ObjectNamePart::Identifier(Ident::new(
322            metric_name,
323        ))]));
324    }
325}
326
327/// translate the start location to the index in the sql string
328pub fn location_to_index(sql: &str, location: &sqlparser::tokenizer::Location) -> usize {
329    let mut index = 0;
330    for (lno, line) in sql.lines().enumerate() {
331        if lno + 1 == location.line as usize {
332            index += location.column as usize;
333            break;
334        } else {
335            index += line.len() + 1; // +1 for the newline
336        }
337    }
338    // -1 because the index is 0-based
339    // and the location is 1-based
340    index - 1
341}
342
343/// Helper function for [extract_tables_from_query].
344///
345/// Handle [sqlparser::ast::Query].
346fn extract_tables_from_sql_query(query: &sqlparser::ast::Query, names: &mut HashSet<ObjectName>) {
347    let mut cte_names = HashSet::new();
348    if let Some(with) = &query.with {
349        for cte in &with.cte_tables {
350            let cte_name = ParserContext::canonicalize_identifier(cte.alias.name.clone()).value;
351            let mut cte_query_names = HashSet::new();
352            extract_tables_from_sql_query(&cte.query, &mut cte_query_names);
353            if with.recursive {
354                cte_names.insert(cte_name.clone());
355            }
356            remove_cte_names(&mut cte_query_names, &cte_names);
357            names.extend(cte_query_names);
358            if !with.recursive {
359                cte_names.insert(cte_name);
360            }
361        }
362    }
363
364    let mut body_names = HashSet::new();
365    extract_tables_from_set_expr(&query.body, &mut body_names);
366    remove_cte_names(&mut body_names, &cte_names);
367    names.extend(body_names);
368}
369
370/// Helper function for [extract_tables_from_query].
371///
372/// Handle [SetExpr].
373fn extract_tables_from_set_expr(set_expr: &SetExpr, names: &mut HashSet<ObjectName>) {
374    match set_expr {
375        SetExpr::Select(select) => {
376            for from in &select.from {
377                extract_tables_from_table_with_joins(from, names);
378            }
379        }
380        SetExpr::Query(query) => {
381            extract_tables_from_sql_query(query, names);
382        }
383        SetExpr::SetOperation { left, right, .. } => {
384            extract_tables_from_set_expr(left, names);
385            extract_tables_from_set_expr(right, names);
386        }
387        _ => {}
388    };
389}
390
391/// Helper function for [extract_tables_from_query].
392///
393/// Handle [TableWithJoins].
394fn extract_tables_from_table_with_joins(
395    table_with_joins: &TableWithJoins,
396    names: &mut HashSet<ObjectName>,
397) {
398    table_factor_to_object_name(&table_with_joins.relation, names);
399    for join in &table_with_joins.joins {
400        table_factor_to_object_name(&join.relation, names);
401    }
402}
403
404/// Helper function for [extract_tables_from_query].
405///
406/// Handle [TableFactor].
407fn table_factor_to_object_name(table_factor: &TableFactor, names: &mut HashSet<ObjectName>) {
408    match table_factor {
409        TableFactor::Table { name, .. } => {
410            names.insert(name.to_owned());
411        }
412        TableFactor::Derived { subquery, .. } => {
413            extract_tables_from_sql_query(subquery, names);
414        }
415        TableFactor::NestedJoin {
416            table_with_joins, ..
417        } => {
418            extract_tables_from_table_with_joins(table_with_joins, names);
419        }
420        TableFactor::Pivot { table, .. }
421        | TableFactor::Unpivot { table, .. }
422        | TableFactor::MatchRecognize { table, .. } => {
423            table_factor_to_object_name(table, names);
424        }
425        TableFactor::TableFunction { .. }
426        | TableFactor::Function { .. }
427        | TableFactor::UNNEST { .. }
428        | TableFactor::JsonTable { .. }
429        | TableFactor::OpenJsonTable { .. }
430        | TableFactor::XmlTable { .. }
431        | TableFactor::SemanticView { .. } => {}
432    }
433}
434
// Unit tests for the location and table-extraction helpers of this module.
#[cfg(test)]
mod tests {
    use sqlparser::tokenizer::Token;

    use super::*;
    use crate::dialect::GreptimeDbDialect;
    use crate::parser::{ParseOptions, ParserContext};
    use crate::statements::statement::Statement;

    // Round-trip check: for every token the parser yields, the byte range computed from
    // its span with `location_to_index` must slice exactly the token's text out of the SQL.
    #[test]
    fn test_location_to_index() {
        let testcases = vec![
            "SELECT * FROM t WHERE a = 1",
            // inputs that start and/or end with a newline
            r"
SELECT *
FROM
t
WHERE a =
1
",
            r"SELECT *
FROM
t
WHERE a =
1
",
            r"
SELECT *
FROM
t
WHERE a =
1",
        ];

        for sql in testcases {
            let mut parser = ParserContext::new(&GreptimeDbDialect {}, sql).unwrap();
            // Drain the token stream until EOF.
            loop {
                let token = parser.parser.next_token();
                if token == Token::EOF {
                    break;
                }
                let span = token.span;
                let subslice =
                    &sql[location_to_index(sql, &span.start)..location_to_index(sql, &span.end)];
                assert_eq!(token.to_string(), subslice);
            }
        }
    }

    // The metric name is extracted both from a bare selector and from a
    // `{__name__="..."}` matcher in a `CREATE FLOW ... AS TQL EVAL` query.
    #[test]
    fn test_extract_tables_from_tql_query() {
        let testcases = vec![
            (
                r#"
CREATE FLOW calc_reqs SINK TO cnt_reqs AS
TQL EVAL (now() - '15s'::interval, now(), '5s') count_values("status_code", http_requests);"#,
                vec!["http_requests".to_string()],
            ),
            (
                r#"
CREATE FLOW calc_reqs SINK TO cnt_reqs AS
TQL EVAL (now() - '15s'::interval, now(), '5s') count_values("status_code", {__name__="http_requests"});"#,
                vec!["http_requests".to_string()],
            ),
        ];

        for (sql, expected_tables) in testcases {
            let mut stmts = ParserContext::create_with_dialect(
                sql,
                &GreptimeDbDialect {},
                ParseOptions::default(),
            )
            .unwrap();
            let Statement::CreateFlow(create_flow) = stmts.pop().unwrap() else {
                unreachable!()
            };

            // Sorted because extraction yields names in HashSet iteration order.
            let mut tables = extract_tables_from_query(&create_flow.query)
                .map(|table| format_raw_object_name(&table))
                .collect_vec();
            tables.sort();
            assert_eq!(expected_tables, tables);
        }
    }

    // Tables that only appear inside derived subqueries on both sides of a join are
    // still reported.
    #[test]
    fn test_extract_tables_from_sql_query_with_derived_join() {
        let sql = r#"
CREATE FLOW flow_batch_join_subquery SINK TO flow_batch_join_sink
EVAL INTERVAL '1m' AS
SELECT a.symbol, b.mark_price
FROM (
    SELECT inst_id AS symbol, max(ts) AS mark_iv_ts
    FROM flow_batch_join_opt_summary
    GROUP BY inst_id
) a
LEFT JOIN (
    SELECT symbol, max(mark_price) AS mark_price
    FROM flow_batch_join_market_v5
    WHERE "type" = 'OPTION_MARK'
    GROUP BY symbol
) b ON a.symbol = b.symbol;
"#;
        let mut stmts =
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
                .unwrap();
        let Statement::CreateFlow(create_flow) = stmts.pop().unwrap() else {
            unreachable!()
        };

        let mut tables = extract_tables_from_query(&create_flow.query)
            .map(|table| format_raw_object_name(&table))
            .collect_vec();
        tables.sort();
        assert_eq!(
            vec![
                "flow_batch_join_market_v5".to_string(),
                "flow_batch_join_opt_summary".to_string(),
            ],
            tables
        );
    }

    // CTE names hide tables only where they are in scope:
    // - a non-recursive CTE body that names its own CTE reads the real table;
    // - chained CTEs resolve down to the underlying physical table.
    #[test]
    fn test_extract_tables_from_sql_query_with_cte_scopes() {
        let testcases = vec![
            (
                r#"
WITH source AS (
    SELECT * FROM source
)
SELECT * FROM source;
"#,
                vec!["source".to_string()],
            ),
            (
                r#"
WITH first_cte AS (
    SELECT * FROM physical_source
), second_cte AS (
    SELECT * FROM first_cte
)
SELECT * FROM second_cte;
"#,
                vec!["physical_source".to_string()],
            ),
        ];

        for (sql, expected_tables) in testcases {
            let mut stmts = ParserContext::create_with_dialect(
                sql,
                &GreptimeDbDialect {},
                ParseOptions::default(),
            )
            .unwrap();
            let Statement::Query(query) = stmts.pop().unwrap() else {
                unreachable!()
            };

            // Exercise the SQL helper directly on the plain query.
            let mut tables = HashSet::new();
            extract_tables_from_sql_query(&query.inner, &mut tables);
            let mut tables = tables
                .into_iter()
                .map(|table| format_raw_object_name(&table))
                .collect_vec();
            tables.sort();
            assert_eq!(expected_tables, tables);
        }
    }

    // A `__schema__` equality matcher qualifies the extracted metric table with that schema.
    #[test]
    fn test_extract_tables_from_tql_query_with_schema_matcher() {
        let sql = r#"
CREATE FLOW calc_reqs SINK TO cnt_reqs AS
TQL EVAL (now() - '15s'::interval, now(), '5s') count_values("status_code", http_requests{__schema__="greptime_private"});"#;
        let mut stmts =
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
                .unwrap();
        let Statement::CreateFlow(create_flow) = stmts.pop().unwrap() else {
            unreachable!()
        };

        let mut tables = extract_tables_from_query(&create_flow.query)
            .map(|table| format_raw_object_name(&table))
            .collect_vec();
        tables.sort();
        assert_eq!(vec!["greptime_private.http_requests".to_string()], tables);
    }
}