Skip to main content

sql/
util.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashSet;
16use std::fmt::{Display, Formatter};
17
18use itertools::Itertools;
19use promql_parser::label::{METRIC_NAME, MatchOp};
20use promql_parser::parser::{
21    AggregateExpr as PromAggregateExpr, BinaryExpr as PromBinaryExpr, Call as PromCall,
22    Expr as PromExpr, MatrixSelector as PromMatrixSelector, ParenExpr as PromParenExpr,
23    SubqueryExpr as PromSubqueryExpr, UnaryExpr as PromUnaryExpr,
24    VectorSelector as PromVectorSelector,
25};
26use serde::Serialize;
27use snafu::ensure;
28use sqlparser::ast::{
29    Array, Expr, Ident, ObjectName, ObjectNamePart, SetExpr, SqlOption, StructField, TableFactor,
30    TableWithJoins, Value, ValueWithSpan,
31};
32use sqlparser_derive::{Visit, VisitMut};
33
34use crate::ast::ObjectNamePartExt;
35use crate::error::{InvalidExprAsOptionValueSnafu, InvalidSqlSnafu, Result};
36use crate::parser::ParserContext;
37use crate::parsers::with_tql_parser::CteContent;
38use crate::statements::create::SqlOrTql;
39use crate::statements::query::Query;
40use crate::statements::tql::Tql;
41
// Label names that select a schema inside a PromQL vector selector. When one of them
// is used with an equality matcher, `extract_metric_name_from_vector_selector` puts the
// matcher's value in front of the metric name in the extracted table name.
const SCHEMA_MATCHER: &str = "__schema__";
const DATABASE_MATCHER: &str = "__database__";
44
45/// Format an [ObjectName] without any quote of its idents.
46pub fn format_raw_object_name(name: &ObjectName) -> String {
47    struct Inner<'a> {
48        name: &'a ObjectName,
49    }
50
51    impl Display for Inner<'_> {
52        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
53            let mut delim = "";
54            for ident in self.name.0.iter() {
55                write!(f, "{delim}")?;
56                delim = ".";
57                write!(f, "{}", ident.to_string_unquoted())?;
58            }
59            Ok(())
60        }
61    }
62
63    format!("{}", Inner { name })
64}
65
/// The value of a SQL option, kept as the underlying [Expr].
///
/// [OptionValue::try_new] only accepts value, identifier, array and struct expressions.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Visit, VisitMut)]
pub struct OptionValue(Expr);
68
69impl OptionValue {
70    pub(crate) fn try_new(expr: Expr) -> Result<Self> {
71        ensure!(
72            matches!(
73                expr,
74                Expr::Value(_) | Expr::Identifier(_) | Expr::Array(_) | Expr::Struct { .. }
75            ),
76            InvalidExprAsOptionValueSnafu {
77                error: format!("{expr} not accepted")
78            }
79        );
80        Ok(Self(expr))
81    }
82
83    fn expr_as_string(expr: &Expr) -> Option<&str> {
84        match expr {
85            Expr::Value(ValueWithSpan { value, .. }) => match value {
86                Value::SingleQuotedString(s)
87                | Value::DoubleQuotedString(s)
88                | Value::TripleSingleQuotedString(s)
89                | Value::TripleDoubleQuotedString(s)
90                | Value::SingleQuotedByteStringLiteral(s)
91                | Value::DoubleQuotedByteStringLiteral(s)
92                | Value::TripleSingleQuotedByteStringLiteral(s)
93                | Value::TripleDoubleQuotedByteStringLiteral(s)
94                | Value::SingleQuotedRawStringLiteral(s)
95                | Value::DoubleQuotedRawStringLiteral(s)
96                | Value::TripleSingleQuotedRawStringLiteral(s)
97                | Value::TripleDoubleQuotedRawStringLiteral(s)
98                | Value::EscapedStringLiteral(s)
99                | Value::UnicodeStringLiteral(s)
100                | Value::NationalStringLiteral(s)
101                | Value::HexStringLiteral(s) => Some(s),
102                Value::DollarQuotedString(s) => Some(&s.value),
103                Value::Number(s, _) => Some(s),
104                Value::Boolean(b) => Some(if *b { "true" } else { "false" }),
105                _ => None,
106            },
107            Expr::Identifier(ident) => Some(&ident.value),
108            _ => None,
109        }
110    }
111
112    /// Convert the option value to a string.
113    ///
114    /// Notes: Not all values can be converted to a string, refer to [Self::expr_as_string] for more details.
115    pub fn as_string(&self) -> Option<&str> {
116        Self::expr_as_string(&self.0)
117    }
118
119    pub fn as_list(&self) -> Option<Vec<&str>> {
120        let expr = &self.0;
121        match expr {
122            Expr::Value(_) | Expr::Identifier(_) => self.as_string().map(|s| vec![s]),
123            Expr::Array(array) => array
124                .elem
125                .iter()
126                .map(Self::expr_as_string)
127                .collect::<Option<Vec<_>>>(),
128            _ => None,
129        }
130    }
131
132    pub(crate) fn as_struct_fields(&self) -> Option<&[StructField]> {
133        match &self.0 {
134            Expr::Struct { fields, .. } => Some(fields),
135            _ => None,
136        }
137    }
138}
139
140impl From<String> for OptionValue {
141    fn from(value: String) -> Self {
142        Self(Expr::Identifier(Ident::new(value)))
143    }
144}
145
146impl From<&str> for OptionValue {
147    fn from(value: &str) -> Self {
148        Self(Expr::Identifier(Ident::new(value)))
149    }
150}
151
152impl From<Vec<&str>> for OptionValue {
153    fn from(value: Vec<&str>) -> Self {
154        Self(Expr::Array(Array {
155            elem: value
156                .into_iter()
157                .map(|x| Expr::Identifier(Ident::new(x)))
158                .collect(),
159            named: false,
160        }))
161    }
162}
163
164impl Display for OptionValue {
165    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
166        if let Some(s) = self.as_string() {
167            write!(f, "'{s}'")
168        } else if let Some(s) = self.as_list() {
169            write!(
170                f,
171                "[{}]",
172                s.into_iter().map(|x| format!("'{x}'")).join(", ")
173            )
174        } else {
175            write!(f, "'{}'", self.0)
176        }
177    }
178}
179
180pub fn parse_option_string(option: SqlOption) -> Result<(String, OptionValue)> {
181    let SqlOption::KeyValue { key, value } = option else {
182        return InvalidSqlSnafu {
183            msg: "Expecting a key-value pair in the option",
184        }
185        .fail();
186    };
187    let v = OptionValue::try_new(value)?;
188    let k = key.value.to_lowercase();
189    Ok((k, v))
190}
191
192/// Walk through a [Query] and extract all the tables referenced in it.
193pub fn extract_tables_from_query(query: &SqlOrTql) -> impl Iterator<Item = ObjectName> {
194    let mut names = HashSet::new();
195
196    match query {
197        SqlOrTql::Sql(query, _) => {
198            extract_tables_from_sql_query(&query.inner, &mut names);
199            extract_tables_from_hybrid_cte_query(query, &mut names);
200        }
201        SqlOrTql::Tql(tql, _) => extract_tables_from_tql(tql, &mut names),
202    }
203
204    names.into_iter()
205}
206
207fn extract_tables_from_hybrid_cte_query(query: &Query, sql_names: &mut HashSet<ObjectName>) {
208    if let Some(hybrid_cte) = &query.hybrid_cte {
209        let mut cte_names: HashSet<String> = hybrid_cte
210            .cte_tables
211            .iter()
212            .map(|cte| ParserContext::canonicalize_identifier(cte.name.clone()).value)
213            .collect();
214        remove_cte_names(sql_names, &cte_names);
215
216        cte_names.clear();
217        for cte in &hybrid_cte.cte_tables {
218            let cte_name = ParserContext::canonicalize_identifier(cte.name.clone()).value;
219            let mut cte_query_names = HashSet::new();
220            match &cte.content {
221                CteContent::Sql(cte_query) => {
222                    extract_tables_from_sql_query(cte_query, &mut cte_query_names)
223                }
224                CteContent::Tql(tql) => extract_tables_from_tql(tql, &mut cte_query_names),
225            }
226            if hybrid_cte.recursive {
227                cte_names.insert(cte_name.clone());
228            }
229            remove_cte_names(&mut cte_query_names, &cte_names);
230            sql_names.extend(cte_query_names);
231            if !hybrid_cte.recursive {
232                cte_names.insert(cte_name);
233            }
234        }
235    }
236}
237
238fn remove_cte_names(names: &mut HashSet<ObjectName>, cte_names: &HashSet<String>) {
239    if cte_names.is_empty() {
240        return;
241    }
242
243    names.retain(|name| {
244        if name.0.len() != 1 {
245            return true;
246        }
247        let Some(ident) = name.0[0].as_ident() else {
248            return true;
249        };
250
251        let canonical = ParserContext::canonicalize_identifier(ident.clone()).value;
252        !cte_names.contains(&canonical)
253    });
254}
255
256fn extract_tables_from_tql(tql: &Tql, names: &mut HashSet<ObjectName>) {
257    let promql = match tql {
258        Tql::Eval(eval) => &eval.query,
259        Tql::Explain(explain) => &explain.query,
260        Tql::Analyze(analyze) => &analyze.query,
261    };
262
263    if let Ok(expr) = promql_parser::parser::parse(promql) {
264        extract_tables_from_prom_expr(&expr, names);
265    }
266}
267
268fn extract_tables_from_prom_expr(expr: &PromExpr, names: &mut HashSet<ObjectName>) {
269    match expr {
270        PromExpr::Aggregate(PromAggregateExpr { expr, .. }) => {
271            extract_tables_from_prom_expr(expr, names);
272        }
273        PromExpr::Unary(PromUnaryExpr { expr, .. }) => {
274            extract_tables_from_prom_expr(expr, names);
275        }
276        PromExpr::Binary(PromBinaryExpr { lhs, rhs, .. }) => {
277            extract_tables_from_prom_expr(lhs, names);
278            extract_tables_from_prom_expr(rhs, names);
279        }
280        PromExpr::Paren(PromParenExpr { expr }) => {
281            extract_tables_from_prom_expr(expr, names);
282        }
283        PromExpr::Subquery(PromSubqueryExpr { expr, .. }) => {
284            extract_tables_from_prom_expr(expr, names);
285        }
286        PromExpr::VectorSelector(selector) => {
287            extract_metric_name_from_vector_selector(selector, names);
288        }
289        PromExpr::MatrixSelector(PromMatrixSelector { vs, .. }) => {
290            extract_metric_name_from_vector_selector(vs, names);
291        }
292        PromExpr::Call(PromCall { args, .. }) => {
293            for arg in &args.args {
294                extract_tables_from_prom_expr(arg, names);
295            }
296        }
297        PromExpr::NumberLiteral(_) | PromExpr::StringLiteral(_) | PromExpr::Extension(_) => {}
298    }
299}
300
301fn extract_metric_name_from_vector_selector(
302    selector: &PromVectorSelector,
303    names: &mut HashSet<ObjectName>,
304) {
305    let metric_name = selector.name.clone().or_else(|| {
306        let mut metric_name_matchers = selector.matchers.find_matchers(METRIC_NAME);
307        if metric_name_matchers.len() == 1 && metric_name_matchers[0].op == MatchOp::Equal {
308            metric_name_matchers.pop().map(|matcher| matcher.value)
309        } else {
310            None
311        }
312    });
313    let Some(metric_name) = metric_name else {
314        return;
315    };
316
317    let schema_matcher = selector.matchers.matchers.iter().rev().find(|matcher| {
318        matcher.op == MatchOp::Equal
319            && (matcher.name == SCHEMA_MATCHER || matcher.name == DATABASE_MATCHER)
320    });
321
322    if let Some(schema) = schema_matcher {
323        names.insert(ObjectName(vec![
324            ObjectNamePart::Identifier(Ident::new(&schema.value)),
325            ObjectNamePart::Identifier(Ident::new(metric_name)),
326        ]));
327    } else {
328        names.insert(ObjectName(vec![ObjectNamePart::Identifier(Ident::new(
329            metric_name,
330        ))]));
331    }
332}
333
334/// translate the start location to the index in the sql string
335pub fn location_to_index(sql: &str, location: &sqlparser::tokenizer::Location) -> usize {
336    let mut index = 0;
337    for (lno, line) in sql.lines().enumerate() {
338        if lno + 1 == location.line as usize {
339            index += location.column as usize;
340            break;
341        } else {
342            index += line.len() + 1; // +1 for the newline
343        }
344    }
345    // -1 because the index is 0-based
346    // and the location is 1-based
347    index - 1
348}
349
350/// Helper function for [extract_tables_from_query].
351///
352/// Handle [sqlparser::ast::Query].
353fn extract_tables_from_sql_query(query: &sqlparser::ast::Query, names: &mut HashSet<ObjectName>) {
354    let mut cte_names = HashSet::new();
355    if let Some(with) = &query.with {
356        for cte in &with.cte_tables {
357            let cte_name = ParserContext::canonicalize_identifier(cte.alias.name.clone()).value;
358            let mut cte_query_names = HashSet::new();
359            extract_tables_from_sql_query(&cte.query, &mut cte_query_names);
360            if with.recursive {
361                cte_names.insert(cte_name.clone());
362            }
363            remove_cte_names(&mut cte_query_names, &cte_names);
364            names.extend(cte_query_names);
365            if !with.recursive {
366                cte_names.insert(cte_name);
367            }
368        }
369    }
370
371    let mut body_names = HashSet::new();
372    extract_tables_from_set_expr(&query.body, &mut body_names);
373    remove_cte_names(&mut body_names, &cte_names);
374    names.extend(body_names);
375}
376
377/// Helper function for [extract_tables_from_query].
378///
379/// Handle [SetExpr].
380fn extract_tables_from_set_expr(set_expr: &SetExpr, names: &mut HashSet<ObjectName>) {
381    match set_expr {
382        SetExpr::Select(select) => {
383            for from in &select.from {
384                extract_tables_from_table_with_joins(from, names);
385            }
386        }
387        SetExpr::Query(query) => {
388            extract_tables_from_sql_query(query, names);
389        }
390        SetExpr::SetOperation { left, right, .. } => {
391            extract_tables_from_set_expr(left, names);
392            extract_tables_from_set_expr(right, names);
393        }
394        _ => {}
395    };
396}
397
398/// Helper function for [extract_tables_from_query].
399///
400/// Handle [TableWithJoins].
401fn extract_tables_from_table_with_joins(
402    table_with_joins: &TableWithJoins,
403    names: &mut HashSet<ObjectName>,
404) {
405    table_factor_to_object_name(&table_with_joins.relation, names);
406    for join in &table_with_joins.joins {
407        table_factor_to_object_name(&join.relation, names);
408    }
409}
410
411/// Helper function for [extract_tables_from_query].
412///
413/// Handle [TableFactor].
414fn table_factor_to_object_name(table_factor: &TableFactor, names: &mut HashSet<ObjectName>) {
415    match table_factor {
416        TableFactor::Table { name, .. } => {
417            names.insert(name.to_owned());
418        }
419        TableFactor::Derived { subquery, .. } => {
420            extract_tables_from_sql_query(subquery, names);
421        }
422        TableFactor::NestedJoin {
423            table_with_joins, ..
424        } => {
425            extract_tables_from_table_with_joins(table_with_joins, names);
426        }
427        TableFactor::Pivot { table, .. }
428        | TableFactor::Unpivot { table, .. }
429        | TableFactor::MatchRecognize { table, .. } => {
430            table_factor_to_object_name(table, names);
431        }
432        TableFactor::TableFunction { .. }
433        | TableFactor::Function { .. }
434        | TableFactor::UNNEST { .. }
435        | TableFactor::JsonTable { .. }
436        | TableFactor::OpenJsonTable { .. }
437        | TableFactor::XmlTable { .. }
438        | TableFactor::SemanticView { .. } => {}
439    }
440}
441
442#[cfg(test)]
443mod tests {
444    use sqlparser::tokenizer::Token;
445
446    use super::*;
447    use crate::dialect::GreptimeDbDialect;
448    use crate::parser::{ParseOptions, ParserContext};
449    use crate::statements::statement::Statement;
450
451    #[test]
452    fn test_location_to_index() {
453        let testcases = vec![
454            "SELECT * FROM t WHERE a = 1",
455            // start or end with newline
456            r"
457SELECT *
458FROM
459t
460WHERE a =
4611
462",
463            r"SELECT *
464FROM
465t
466WHERE a =
4671
468",
469            r"
470SELECT *
471FROM
472t
473WHERE a =
4741",
475        ];
476
477        for sql in testcases {
478            let mut parser = ParserContext::new(&GreptimeDbDialect {}, sql).unwrap();
479            loop {
480                let token = parser.parser.next_token();
481                if token == Token::EOF {
482                    break;
483                }
484                let span = token.span;
485                let subslice =
486                    &sql[location_to_index(sql, &span.start)..location_to_index(sql, &span.end)];
487                assert_eq!(token.to_string(), subslice);
488            }
489        }
490    }
491
492    #[test]
493    fn test_extract_tables_from_tql_query() {
494        let testcases = vec![
495            (
496                r#"
497CREATE FLOW calc_reqs SINK TO cnt_reqs AS
498TQL EVAL (now() - '15s'::interval, now(), '5s') count_values("status_code", http_requests);"#,
499                vec!["http_requests".to_string()],
500            ),
501            (
502                r#"
503CREATE FLOW calc_reqs SINK TO cnt_reqs AS
504TQL EVAL (now() - '15s'::interval, now(), '5s') count_values("status_code", {__name__="http_requests"});"#,
505                vec!["http_requests".to_string()],
506            ),
507        ];
508
509        for (sql, expected_tables) in testcases {
510            let mut stmts = ParserContext::create_with_dialect(
511                sql,
512                &GreptimeDbDialect {},
513                ParseOptions::default(),
514            )
515            .unwrap();
516            let Statement::CreateFlow(create_flow) = stmts.pop().unwrap() else {
517                unreachable!()
518            };
519
520            let mut tables = extract_tables_from_query(&create_flow.query)
521                .map(|table| format_raw_object_name(&table))
522                .collect_vec();
523            tables.sort();
524            assert_eq!(expected_tables, tables);
525        }
526    }
527
528    #[test]
529    fn test_extract_tables_from_sql_query_with_derived_join() {
530        let sql = r#"
531CREATE FLOW flow_batch_join_subquery SINK TO flow_batch_join_sink
532EVAL INTERVAL '1m' AS
533SELECT a.symbol, b.mark_price
534FROM (
535    SELECT inst_id AS symbol, max(ts) AS mark_iv_ts
536    FROM flow_batch_join_opt_summary
537    GROUP BY inst_id
538) a
539LEFT JOIN (
540    SELECT symbol, max(mark_price) AS mark_price
541    FROM flow_batch_join_market_v5
542    WHERE "type" = 'OPTION_MARK'
543    GROUP BY symbol
544) b ON a.symbol = b.symbol;
545"#;
546        let mut stmts =
547            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
548                .unwrap();
549        let Statement::CreateFlow(create_flow) = stmts.pop().unwrap() else {
550            unreachable!()
551        };
552
553        let mut tables = extract_tables_from_query(&create_flow.query)
554            .map(|table| format_raw_object_name(&table))
555            .collect_vec();
556        tables.sort();
557        assert_eq!(
558            vec![
559                "flow_batch_join_market_v5".to_string(),
560                "flow_batch_join_opt_summary".to_string(),
561            ],
562            tables
563        );
564    }
565
566    #[test]
567    fn test_extract_tables_from_sql_query_with_cte_scopes() {
568        let testcases = vec![
569            (
570                r#"
571WITH source AS (
572    SELECT * FROM source
573)
574SELECT * FROM source;
575"#,
576                vec!["source".to_string()],
577            ),
578            (
579                r#"
580WITH first_cte AS (
581    SELECT * FROM physical_source
582), second_cte AS (
583    SELECT * FROM first_cte
584)
585SELECT * FROM second_cte;
586"#,
587                vec!["physical_source".to_string()],
588            ),
589        ];
590
591        for (sql, expected_tables) in testcases {
592            let mut stmts = ParserContext::create_with_dialect(
593                sql,
594                &GreptimeDbDialect {},
595                ParseOptions::default(),
596            )
597            .unwrap();
598            let Statement::Query(query) = stmts.pop().unwrap() else {
599                unreachable!()
600            };
601
602            let mut tables = HashSet::new();
603            extract_tables_from_sql_query(&query.inner, &mut tables);
604            let mut tables = tables
605                .into_iter()
606                .map(|table| format_raw_object_name(&table))
607                .collect_vec();
608            tables.sort();
609            assert_eq!(expected_tables, tables);
610        }
611    }
612
613    #[test]
614    fn test_extract_tables_from_tql_query_with_schema_matcher() {
615        let sql = r#"
616CREATE FLOW calc_reqs SINK TO cnt_reqs AS
617TQL EVAL (now() - '15s'::interval, now(), '5s') count_values("status_code", http_requests{__schema__="greptime_private"});"#;
618        let mut stmts =
619            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
620                .unwrap();
621        let Statement::CreateFlow(create_flow) = stmts.pop().unwrap() else {
622            unreachable!()
623        };
624
625        let mut tables = extract_tables_from_query(&create_flow.query)
626            .map(|table| format_raw_object_name(&table))
627            .collect_vec();
628        tables.sort();
629        assert_eq!(vec!["greptime_private.http_requests".to_string()], tables);
630    }
631}