1use std::collections::HashSet;
16use std::fmt::{Display, Formatter};
17
18use itertools::Itertools;
19use promql_parser::label::{METRIC_NAME, MatchOp};
20use promql_parser::parser::{
21 AggregateExpr as PromAggregateExpr, BinaryExpr as PromBinaryExpr, Call as PromCall,
22 Expr as PromExpr, MatrixSelector as PromMatrixSelector, ParenExpr as PromParenExpr,
23 SubqueryExpr as PromSubqueryExpr, UnaryExpr as PromUnaryExpr,
24 VectorSelector as PromVectorSelector,
25};
26use serde::Serialize;
27use snafu::ensure;
28use sqlparser::ast::{
29 Array, Expr, Ident, ObjectName, ObjectNamePart, SetExpr, SqlOption, TableFactor,
30 TableWithJoins, Value, ValueWithSpan,
31};
32use sqlparser_derive::{Visit, VisitMut};
33
34use crate::ast::ObjectNamePartExt;
35use crate::error::{InvalidExprAsOptionValueSnafu, InvalidSqlSnafu, Result};
36use crate::parser::ParserContext;
37use crate::parsers::with_tql_parser::CteContent;
38use crate::statements::create::SqlOrTql;
39use crate::statements::query::Query;
40use crate::statements::tql::Tql;
41
/// PromQL label matcher name whose `=` value qualifies a selector's metric with a schema.
const SCHEMA_MATCHER: &str = "__schema__";
/// PromQL label matcher name accepted as an alias of [`SCHEMA_MATCHER`].
const DATABASE_MATCHER: &str = "__database__";
44
45pub fn format_raw_object_name(name: &ObjectName) -> String {
47 struct Inner<'a> {
48 name: &'a ObjectName,
49 }
50
51 impl Display for Inner<'_> {
52 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
53 let mut delim = "";
54 for ident in self.name.0.iter() {
55 write!(f, "{delim}")?;
56 delim = ".";
57 write!(f, "{}", ident.to_string_unquoted())?;
58 }
59 Ok(())
60 }
61 }
62
63 format!("{}", Inner { name })
64}
65
66#[derive(Clone, Debug, PartialEq, Eq, Serialize, Visit, VisitMut)]
67pub struct OptionValue(Expr);
68
69impl OptionValue {
70 pub(crate) fn try_new(expr: Expr) -> Result<Self> {
71 ensure!(
72 matches!(
73 expr,
74 Expr::Value(_) | Expr::Identifier(_) | Expr::Array(_) | Expr::Struct { .. }
75 ),
76 InvalidExprAsOptionValueSnafu {
77 error: format!("{expr} not accepted")
78 }
79 );
80 Ok(Self(expr))
81 }
82
83 fn expr_as_string(expr: &Expr) -> Option<&str> {
84 match expr {
85 Expr::Value(ValueWithSpan { value, .. }) => match value {
86 Value::SingleQuotedString(s)
87 | Value::DoubleQuotedString(s)
88 | Value::TripleSingleQuotedString(s)
89 | Value::TripleDoubleQuotedString(s)
90 | Value::SingleQuotedByteStringLiteral(s)
91 | Value::DoubleQuotedByteStringLiteral(s)
92 | Value::TripleSingleQuotedByteStringLiteral(s)
93 | Value::TripleDoubleQuotedByteStringLiteral(s)
94 | Value::SingleQuotedRawStringLiteral(s)
95 | Value::DoubleQuotedRawStringLiteral(s)
96 | Value::TripleSingleQuotedRawStringLiteral(s)
97 | Value::TripleDoubleQuotedRawStringLiteral(s)
98 | Value::EscapedStringLiteral(s)
99 | Value::UnicodeStringLiteral(s)
100 | Value::NationalStringLiteral(s)
101 | Value::HexStringLiteral(s) => Some(s),
102 Value::DollarQuotedString(s) => Some(&s.value),
103 Value::Number(s, _) => Some(s),
104 Value::Boolean(b) => Some(if *b { "true" } else { "false" }),
105 _ => None,
106 },
107 Expr::Identifier(ident) => Some(&ident.value),
108 _ => None,
109 }
110 }
111
112 pub fn as_string(&self) -> Option<&str> {
116 Self::expr_as_string(&self.0)
117 }
118
119 pub fn as_list(&self) -> Option<Vec<&str>> {
120 let expr = &self.0;
121 match expr {
122 Expr::Value(_) | Expr::Identifier(_) => self.as_string().map(|s| vec![s]),
123 Expr::Array(array) => array
124 .elem
125 .iter()
126 .map(Self::expr_as_string)
127 .collect::<Option<Vec<_>>>(),
128 _ => None,
129 }
130 }
131}
132
133impl From<String> for OptionValue {
134 fn from(value: String) -> Self {
135 Self(Expr::Identifier(Ident::new(value)))
136 }
137}
138
139impl From<&str> for OptionValue {
140 fn from(value: &str) -> Self {
141 Self(Expr::Identifier(Ident::new(value)))
142 }
143}
144
145impl From<Vec<&str>> for OptionValue {
146 fn from(value: Vec<&str>) -> Self {
147 Self(Expr::Array(Array {
148 elem: value
149 .into_iter()
150 .map(|x| Expr::Identifier(Ident::new(x)))
151 .collect(),
152 named: false,
153 }))
154 }
155}
156
157impl Display for OptionValue {
158 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
159 if let Some(s) = self.as_string() {
160 write!(f, "'{s}'")
161 } else if let Some(s) = self.as_list() {
162 write!(
163 f,
164 "[{}]",
165 s.into_iter().map(|x| format!("'{x}'")).join(", ")
166 )
167 } else {
168 write!(f, "'{}'", self.0)
169 }
170 }
171}
172
173pub fn parse_option_string(option: SqlOption) -> Result<(String, OptionValue)> {
174 let SqlOption::KeyValue { key, value } = option else {
175 return InvalidSqlSnafu {
176 msg: "Expecting a key-value pair in the option",
177 }
178 .fail();
179 };
180 let v = OptionValue::try_new(value)?;
181 let k = key.value.to_lowercase();
182 Ok((k, v))
183}
184
185pub fn extract_tables_from_query(query: &SqlOrTql) -> impl Iterator<Item = ObjectName> {
187 let mut names = HashSet::new();
188
189 match query {
190 SqlOrTql::Sql(query, _) => {
191 extract_tables_from_sql_query(&query.inner, &mut names);
192 extract_tables_from_hybrid_cte_query(query, &mut names);
193 }
194 SqlOrTql::Tql(tql, _) => extract_tables_from_tql(tql, &mut names),
195 }
196
197 names.into_iter()
198}
199
200fn extract_tables_from_hybrid_cte_query(query: &Query, sql_names: &mut HashSet<ObjectName>) {
201 if let Some(hybrid_cte) = &query.hybrid_cte {
202 let mut cte_names: HashSet<String> = hybrid_cte
203 .cte_tables
204 .iter()
205 .map(|cte| ParserContext::canonicalize_identifier(cte.name.clone()).value)
206 .collect();
207 remove_cte_names(sql_names, &cte_names);
208
209 cte_names.clear();
210 for cte in &hybrid_cte.cte_tables {
211 let cte_name = ParserContext::canonicalize_identifier(cte.name.clone()).value;
212 let mut cte_query_names = HashSet::new();
213 match &cte.content {
214 CteContent::Sql(cte_query) => {
215 extract_tables_from_sql_query(cte_query, &mut cte_query_names)
216 }
217 CteContent::Tql(tql) => extract_tables_from_tql(tql, &mut cte_query_names),
218 }
219 if hybrid_cte.recursive {
220 cte_names.insert(cte_name.clone());
221 }
222 remove_cte_names(&mut cte_query_names, &cte_names);
223 sql_names.extend(cte_query_names);
224 if !hybrid_cte.recursive {
225 cte_names.insert(cte_name);
226 }
227 }
228 }
229}
230
231fn remove_cte_names(names: &mut HashSet<ObjectName>, cte_names: &HashSet<String>) {
232 if cte_names.is_empty() {
233 return;
234 }
235
236 names.retain(|name| {
237 if name.0.len() != 1 {
238 return true;
239 }
240 let Some(ident) = name.0[0].as_ident() else {
241 return true;
242 };
243
244 let canonical = ParserContext::canonicalize_identifier(ident.clone()).value;
245 !cte_names.contains(&canonical)
246 });
247}
248
249fn extract_tables_from_tql(tql: &Tql, names: &mut HashSet<ObjectName>) {
250 let promql = match tql {
251 Tql::Eval(eval) => &eval.query,
252 Tql::Explain(explain) => &explain.query,
253 Tql::Analyze(analyze) => &analyze.query,
254 };
255
256 if let Ok(expr) = promql_parser::parser::parse(promql) {
257 extract_tables_from_prom_expr(&expr, names);
258 }
259}
260
261fn extract_tables_from_prom_expr(expr: &PromExpr, names: &mut HashSet<ObjectName>) {
262 match expr {
263 PromExpr::Aggregate(PromAggregateExpr { expr, .. }) => {
264 extract_tables_from_prom_expr(expr, names);
265 }
266 PromExpr::Unary(PromUnaryExpr { expr, .. }) => {
267 extract_tables_from_prom_expr(expr, names);
268 }
269 PromExpr::Binary(PromBinaryExpr { lhs, rhs, .. }) => {
270 extract_tables_from_prom_expr(lhs, names);
271 extract_tables_from_prom_expr(rhs, names);
272 }
273 PromExpr::Paren(PromParenExpr { expr }) => {
274 extract_tables_from_prom_expr(expr, names);
275 }
276 PromExpr::Subquery(PromSubqueryExpr { expr, .. }) => {
277 extract_tables_from_prom_expr(expr, names);
278 }
279 PromExpr::VectorSelector(selector) => {
280 extract_metric_name_from_vector_selector(selector, names);
281 }
282 PromExpr::MatrixSelector(PromMatrixSelector { vs, .. }) => {
283 extract_metric_name_from_vector_selector(vs, names);
284 }
285 PromExpr::Call(PromCall { args, .. }) => {
286 for arg in &args.args {
287 extract_tables_from_prom_expr(arg, names);
288 }
289 }
290 PromExpr::NumberLiteral(_) | PromExpr::StringLiteral(_) | PromExpr::Extension(_) => {}
291 }
292}
293
294fn extract_metric_name_from_vector_selector(
295 selector: &PromVectorSelector,
296 names: &mut HashSet<ObjectName>,
297) {
298 let metric_name = selector.name.clone().or_else(|| {
299 let mut metric_name_matchers = selector.matchers.find_matchers(METRIC_NAME);
300 if metric_name_matchers.len() == 1 && metric_name_matchers[0].op == MatchOp::Equal {
301 metric_name_matchers.pop().map(|matcher| matcher.value)
302 } else {
303 None
304 }
305 });
306 let Some(metric_name) = metric_name else {
307 return;
308 };
309
310 let schema_matcher = selector.matchers.matchers.iter().rev().find(|matcher| {
311 matcher.op == MatchOp::Equal
312 && (matcher.name == SCHEMA_MATCHER || matcher.name == DATABASE_MATCHER)
313 });
314
315 if let Some(schema) = schema_matcher {
316 names.insert(ObjectName(vec![
317 ObjectNamePart::Identifier(Ident::new(&schema.value)),
318 ObjectNamePart::Identifier(Ident::new(metric_name)),
319 ]));
320 } else {
321 names.insert(ObjectName(vec![ObjectNamePart::Identifier(Ident::new(
322 metric_name,
323 ))]));
324 }
325}
326
327pub fn location_to_index(sql: &str, location: &sqlparser::tokenizer::Location) -> usize {
329 let mut index = 0;
330 for (lno, line) in sql.lines().enumerate() {
331 if lno + 1 == location.line as usize {
332 index += location.column as usize;
333 break;
334 } else {
335 index += line.len() + 1; }
337 }
338 index - 1
341}
342
343fn extract_tables_from_sql_query(query: &sqlparser::ast::Query, names: &mut HashSet<ObjectName>) {
347 let mut cte_names = HashSet::new();
348 if let Some(with) = &query.with {
349 for cte in &with.cte_tables {
350 let cte_name = ParserContext::canonicalize_identifier(cte.alias.name.clone()).value;
351 let mut cte_query_names = HashSet::new();
352 extract_tables_from_sql_query(&cte.query, &mut cte_query_names);
353 if with.recursive {
354 cte_names.insert(cte_name.clone());
355 }
356 remove_cte_names(&mut cte_query_names, &cte_names);
357 names.extend(cte_query_names);
358 if !with.recursive {
359 cte_names.insert(cte_name);
360 }
361 }
362 }
363
364 let mut body_names = HashSet::new();
365 extract_tables_from_set_expr(&query.body, &mut body_names);
366 remove_cte_names(&mut body_names, &cte_names);
367 names.extend(body_names);
368}
369
370fn extract_tables_from_set_expr(set_expr: &SetExpr, names: &mut HashSet<ObjectName>) {
374 match set_expr {
375 SetExpr::Select(select) => {
376 for from in &select.from {
377 extract_tables_from_table_with_joins(from, names);
378 }
379 }
380 SetExpr::Query(query) => {
381 extract_tables_from_sql_query(query, names);
382 }
383 SetExpr::SetOperation { left, right, .. } => {
384 extract_tables_from_set_expr(left, names);
385 extract_tables_from_set_expr(right, names);
386 }
387 _ => {}
388 };
389}
390
391fn extract_tables_from_table_with_joins(
395 table_with_joins: &TableWithJoins,
396 names: &mut HashSet<ObjectName>,
397) {
398 table_factor_to_object_name(&table_with_joins.relation, names);
399 for join in &table_with_joins.joins {
400 table_factor_to_object_name(&join.relation, names);
401 }
402}
403
404fn table_factor_to_object_name(table_factor: &TableFactor, names: &mut HashSet<ObjectName>) {
408 match table_factor {
409 TableFactor::Table { name, .. } => {
410 names.insert(name.to_owned());
411 }
412 TableFactor::Derived { subquery, .. } => {
413 extract_tables_from_sql_query(subquery, names);
414 }
415 TableFactor::NestedJoin {
416 table_with_joins, ..
417 } => {
418 extract_tables_from_table_with_joins(table_with_joins, names);
419 }
420 TableFactor::Pivot { table, .. }
421 | TableFactor::Unpivot { table, .. }
422 | TableFactor::MatchRecognize { table, .. } => {
423 table_factor_to_object_name(table, names);
424 }
425 TableFactor::TableFunction { .. }
426 | TableFactor::Function { .. }
427 | TableFactor::UNNEST { .. }
428 | TableFactor::JsonTable { .. }
429 | TableFactor::OpenJsonTable { .. }
430 | TableFactor::XmlTable { .. }
431 | TableFactor::SemanticView { .. } => {}
432 }
433}
434
#[cfg(test)]
mod tests {
    use sqlparser::tokenizer::Token;

    use super::*;
    use crate::dialect::GreptimeDbDialect;
    use crate::parser::{ParseOptions, ParserContext};
    use crate::statements::statement::Statement;

    /// Every non-EOF token's span, converted with `location_to_index`, must
    /// slice exactly the token's text back out of the SQL. This is checked for
    /// single-line input and for multi-line input with and without
    /// leading/trailing newlines.
    #[test]
    fn test_location_to_index() {
        let testcases = vec![
            "SELECT * FROM t WHERE a = 1",
            r"
SELECT *
FROM
t
WHERE a =
1
",
            r"SELECT *
FROM
t
WHERE a =
1
",
            r"
SELECT *
FROM
t
WHERE a =
1",
        ];

        for sql in testcases {
            let mut parser = ParserContext::new(&GreptimeDbDialect {}, sql).unwrap();
            // `next_token` is called until EOF; the EOF token itself is not checked.
            loop {
                let token = parser.parser.next_token();
                if token == Token::EOF {
                    break;
                }
                let span = token.span;
                let subslice =
                    &sql[location_to_index(sql, &span.start)..location_to_index(sql, &span.end)];
                assert_eq!(token.to_string(), subslice);
            }
        }
    }

    /// TQL flow sources report the metric as a table, whether it is written
    /// as a bare selector name or via a `__name__` matcher.
    #[test]
    fn test_extract_tables_from_tql_query() {
        let testcases = vec![
            (
                r#"
CREATE FLOW calc_reqs SINK TO cnt_reqs AS
TQL EVAL (now() - '15s'::interval, now(), '5s') count_values("status_code", http_requests);"#,
                vec!["http_requests".to_string()],
            ),
            (
                r#"
CREATE FLOW calc_reqs SINK TO cnt_reqs AS
TQL EVAL (now() - '15s'::interval, now(), '5s') count_values("status_code", {__name__="http_requests"});"#,
                vec!["http_requests".to_string()],
            ),
        ];

        for (sql, expected_tables) in testcases {
            let mut stmts = ParserContext::create_with_dialect(
                sql,
                &GreptimeDbDialect {},
                ParseOptions::default(),
            )
            .unwrap();
            let Statement::CreateFlow(create_flow) = stmts.pop().unwrap() else {
                unreachable!()
            };

            // Sort because `extract_tables_from_query` yields names in set order.
            let mut tables = extract_tables_from_query(&create_flow.query)
                .map(|table| format_raw_object_name(&table))
                .collect_vec();
            tables.sort();
            assert_eq!(expected_tables, tables);
        }
    }

    /// Tables inside derived sub-queries on both sides of a join are found.
    #[test]
    fn test_extract_tables_from_sql_query_with_derived_join() {
        let sql = r#"
CREATE FLOW flow_batch_join_subquery SINK TO flow_batch_join_sink
EVAL INTERVAL '1m' AS
SELECT a.symbol, b.mark_price
FROM (
    SELECT inst_id AS symbol, max(ts) AS mark_iv_ts
    FROM flow_batch_join_opt_summary
    GROUP BY inst_id
) a
LEFT JOIN (
    SELECT symbol, max(mark_price) AS mark_price
    FROM flow_batch_join_market_v5
    WHERE "type" = 'OPTION_MARK'
    GROUP BY symbol
) b ON a.symbol = b.symbol;
"#;
        let mut stmts =
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
                .unwrap();
        let Statement::CreateFlow(create_flow) = stmts.pop().unwrap() else {
            unreachable!()
        };

        let mut tables = extract_tables_from_query(&create_flow.query)
            .map(|table| format_raw_object_name(&table))
            .collect_vec();
        tables.sort();
        assert_eq!(
            vec![
                "flow_batch_join_market_v5".to_string(),
                "flow_batch_join_opt_summary".to_string(),
            ],
            tables
        );
    }

    /// CTE scoping cases:
    /// - a non-recursive CTE that reads a physical table of the same name
    ///   still reports that table;
    /// - references to earlier CTEs are not reported as tables.
    #[test]
    fn test_extract_tables_from_sql_query_with_cte_scopes() {
        let testcases = vec![
            (
                r#"
WITH source AS (
    SELECT * FROM source
)
SELECT * FROM source;
"#,
                vec!["source".to_string()],
            ),
            (
                r#"
WITH first_cte AS (
    SELECT * FROM physical_source
), second_cte AS (
    SELECT * FROM first_cte
)
SELECT * FROM second_cte;
"#,
                vec!["physical_source".to_string()],
            ),
        ];

        for (sql, expected_tables) in testcases {
            let mut stmts = ParserContext::create_with_dialect(
                sql,
                &GreptimeDbDialect {},
                ParseOptions::default(),
            )
            .unwrap();
            let Statement::Query(query) = stmts.pop().unwrap() else {
                unreachable!()
            };

            let mut tables = HashSet::new();
            extract_tables_from_sql_query(&query.inner, &mut tables);
            let mut tables = tables
                .into_iter()
                .map(|table| format_raw_object_name(&table))
                .collect_vec();
            tables.sort();
            assert_eq!(expected_tables, tables);
        }
    }

    /// A `__schema__` equality matcher qualifies the extracted metric table.
    #[test]
    fn test_extract_tables_from_tql_query_with_schema_matcher() {
        let sql = r#"
CREATE FLOW calc_reqs SINK TO cnt_reqs AS
TQL EVAL (now() - '15s'::interval, now(), '5s') count_values("status_code", http_requests{__schema__="greptime_private"});"#;
        let mut stmts =
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
                .unwrap();
        let Statement::CreateFlow(create_flow) = stmts.pop().unwrap() else {
            unreachable!()
        };

        let mut tables = extract_tables_from_query(&create_flow.query)
            .map(|table| format_raw_object_name(&table))
            .collect_vec();
        tables.sort();
        assert_eq!(vec!["greptime_private.http_requests".to_string()], tables);
    }
}