1use std::collections::HashSet;
16use std::fmt::{Display, Formatter};
17
18use itertools::Itertools;
19use promql_parser::label::{METRIC_NAME, MatchOp};
20use promql_parser::parser::{
21 AggregateExpr as PromAggregateExpr, BinaryExpr as PromBinaryExpr, Call as PromCall,
22 Expr as PromExpr, MatrixSelector as PromMatrixSelector, ParenExpr as PromParenExpr,
23 SubqueryExpr as PromSubqueryExpr, UnaryExpr as PromUnaryExpr,
24 VectorSelector as PromVectorSelector,
25};
26use serde::Serialize;
27use snafu::ensure;
28use sqlparser::ast::{
29 Array, Expr, Ident, ObjectName, ObjectNamePart, SetExpr, SqlOption, StructField, TableFactor,
30 TableWithJoins, Value, ValueWithSpan,
31};
32use sqlparser_derive::{Visit, VisitMut};
33
34use crate::ast::ObjectNamePartExt;
35use crate::error::{InvalidExprAsOptionValueSnafu, InvalidSqlSnafu, Result};
36use crate::parser::ParserContext;
37use crate::parsers::with_tql_parser::CteContent;
38use crate::statements::create::SqlOrTql;
39use crate::statements::query::Query;
40use crate::statements::tql::Tql;
41
/// Label name that, when matched with `=` in a PromQL vector selector, supplies
/// the schema part of the extracted table name.
const SCHEMA_MATCHER: &str = "__schema__";
/// Alternative label name that is treated exactly like [`SCHEMA_MATCHER`] when
/// extracting table names from a PromQL vector selector.
const DATABASE_MATCHER: &str = "__database__";
44
45pub fn format_raw_object_name(name: &ObjectName) -> String {
47 struct Inner<'a> {
48 name: &'a ObjectName,
49 }
50
51 impl Display for Inner<'_> {
52 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
53 let mut delim = "";
54 for ident in self.name.0.iter() {
55 write!(f, "{delim}")?;
56 delim = ".";
57 write!(f, "{}", ident.to_string_unquoted())?;
58 }
59 Ok(())
60 }
61 }
62
63 format!("{}", Inner { name })
64}
65
/// The value side of a SQL option, stored as the parsed [`Expr`].
///
/// Values built through `OptionValue::try_new` are restricted to literal
/// values, identifiers, arrays and structs.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Visit, VisitMut)]
pub struct OptionValue(Expr);
68
69impl OptionValue {
70 pub(crate) fn try_new(expr: Expr) -> Result<Self> {
71 ensure!(
72 matches!(
73 expr,
74 Expr::Value(_) | Expr::Identifier(_) | Expr::Array(_) | Expr::Struct { .. }
75 ),
76 InvalidExprAsOptionValueSnafu {
77 error: format!("{expr} not accepted")
78 }
79 );
80 Ok(Self(expr))
81 }
82
83 fn expr_as_string(expr: &Expr) -> Option<&str> {
84 match expr {
85 Expr::Value(ValueWithSpan { value, .. }) => match value {
86 Value::SingleQuotedString(s)
87 | Value::DoubleQuotedString(s)
88 | Value::TripleSingleQuotedString(s)
89 | Value::TripleDoubleQuotedString(s)
90 | Value::SingleQuotedByteStringLiteral(s)
91 | Value::DoubleQuotedByteStringLiteral(s)
92 | Value::TripleSingleQuotedByteStringLiteral(s)
93 | Value::TripleDoubleQuotedByteStringLiteral(s)
94 | Value::SingleQuotedRawStringLiteral(s)
95 | Value::DoubleQuotedRawStringLiteral(s)
96 | Value::TripleSingleQuotedRawStringLiteral(s)
97 | Value::TripleDoubleQuotedRawStringLiteral(s)
98 | Value::EscapedStringLiteral(s)
99 | Value::UnicodeStringLiteral(s)
100 | Value::NationalStringLiteral(s)
101 | Value::HexStringLiteral(s) => Some(s),
102 Value::DollarQuotedString(s) => Some(&s.value),
103 Value::Number(s, _) => Some(s),
104 Value::Boolean(b) => Some(if *b { "true" } else { "false" }),
105 _ => None,
106 },
107 Expr::Identifier(ident) => Some(&ident.value),
108 _ => None,
109 }
110 }
111
112 pub fn as_string(&self) -> Option<&str> {
116 Self::expr_as_string(&self.0)
117 }
118
119 pub fn as_list(&self) -> Option<Vec<&str>> {
120 let expr = &self.0;
121 match expr {
122 Expr::Value(_) | Expr::Identifier(_) => self.as_string().map(|s| vec![s]),
123 Expr::Array(array) => array
124 .elem
125 .iter()
126 .map(Self::expr_as_string)
127 .collect::<Option<Vec<_>>>(),
128 _ => None,
129 }
130 }
131
132 pub(crate) fn as_struct_fields(&self) -> Option<&[StructField]> {
133 match &self.0 {
134 Expr::Struct { fields, .. } => Some(fields),
135 _ => None,
136 }
137 }
138}
139
140impl From<String> for OptionValue {
141 fn from(value: String) -> Self {
142 Self(Expr::Identifier(Ident::new(value)))
143 }
144}
145
146impl From<&str> for OptionValue {
147 fn from(value: &str) -> Self {
148 Self(Expr::Identifier(Ident::new(value)))
149 }
150}
151
152impl From<Vec<&str>> for OptionValue {
153 fn from(value: Vec<&str>) -> Self {
154 Self(Expr::Array(Array {
155 elem: value
156 .into_iter()
157 .map(|x| Expr::Identifier(Ident::new(x)))
158 .collect(),
159 named: false,
160 }))
161 }
162}
163
164impl Display for OptionValue {
165 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
166 if let Some(s) = self.as_string() {
167 write!(f, "'{s}'")
168 } else if let Some(s) = self.as_list() {
169 write!(
170 f,
171 "[{}]",
172 s.into_iter().map(|x| format!("'{x}'")).join(", ")
173 )
174 } else {
175 write!(f, "'{}'", self.0)
176 }
177 }
178}
179
180pub fn parse_option_string(option: SqlOption) -> Result<(String, OptionValue)> {
181 let SqlOption::KeyValue { key, value } = option else {
182 return InvalidSqlSnafu {
183 msg: "Expecting a key-value pair in the option",
184 }
185 .fail();
186 };
187 let v = OptionValue::try_new(value)?;
188 let k = key.value.to_lowercase();
189 Ok((k, v))
190}
191
192pub fn extract_tables_from_query(query: &SqlOrTql) -> impl Iterator<Item = ObjectName> {
194 let mut names = HashSet::new();
195
196 match query {
197 SqlOrTql::Sql(query, _) => {
198 extract_tables_from_sql_query(&query.inner, &mut names);
199 extract_tables_from_hybrid_cte_query(query, &mut names);
200 }
201 SqlOrTql::Tql(tql, _) => extract_tables_from_tql(tql, &mut names),
202 }
203
204 names.into_iter()
205}
206
207fn extract_tables_from_hybrid_cte_query(query: &Query, sql_names: &mut HashSet<ObjectName>) {
208 if let Some(hybrid_cte) = &query.hybrid_cte {
209 let mut cte_names: HashSet<String> = hybrid_cte
210 .cte_tables
211 .iter()
212 .map(|cte| ParserContext::canonicalize_identifier(cte.name.clone()).value)
213 .collect();
214 remove_cte_names(sql_names, &cte_names);
215
216 cte_names.clear();
217 for cte in &hybrid_cte.cte_tables {
218 let cte_name = ParserContext::canonicalize_identifier(cte.name.clone()).value;
219 let mut cte_query_names = HashSet::new();
220 match &cte.content {
221 CteContent::Sql(cte_query) => {
222 extract_tables_from_sql_query(cte_query, &mut cte_query_names)
223 }
224 CteContent::Tql(tql) => extract_tables_from_tql(tql, &mut cte_query_names),
225 }
226 if hybrid_cte.recursive {
227 cte_names.insert(cte_name.clone());
228 }
229 remove_cte_names(&mut cte_query_names, &cte_names);
230 sql_names.extend(cte_query_names);
231 if !hybrid_cte.recursive {
232 cte_names.insert(cte_name);
233 }
234 }
235 }
236}
237
238fn remove_cte_names(names: &mut HashSet<ObjectName>, cte_names: &HashSet<String>) {
239 if cte_names.is_empty() {
240 return;
241 }
242
243 names.retain(|name| {
244 if name.0.len() != 1 {
245 return true;
246 }
247 let Some(ident) = name.0[0].as_ident() else {
248 return true;
249 };
250
251 let canonical = ParserContext::canonicalize_identifier(ident.clone()).value;
252 !cte_names.contains(&canonical)
253 });
254}
255
256fn extract_tables_from_tql(tql: &Tql, names: &mut HashSet<ObjectName>) {
257 let promql = match tql {
258 Tql::Eval(eval) => &eval.query,
259 Tql::Explain(explain) => &explain.query,
260 Tql::Analyze(analyze) => &analyze.query,
261 };
262
263 if let Ok(expr) = promql_parser::parser::parse(promql) {
264 extract_tables_from_prom_expr(&expr, names);
265 }
266}
267
268fn extract_tables_from_prom_expr(expr: &PromExpr, names: &mut HashSet<ObjectName>) {
269 match expr {
270 PromExpr::Aggregate(PromAggregateExpr { expr, .. }) => {
271 extract_tables_from_prom_expr(expr, names);
272 }
273 PromExpr::Unary(PromUnaryExpr { expr, .. }) => {
274 extract_tables_from_prom_expr(expr, names);
275 }
276 PromExpr::Binary(PromBinaryExpr { lhs, rhs, .. }) => {
277 extract_tables_from_prom_expr(lhs, names);
278 extract_tables_from_prom_expr(rhs, names);
279 }
280 PromExpr::Paren(PromParenExpr { expr }) => {
281 extract_tables_from_prom_expr(expr, names);
282 }
283 PromExpr::Subquery(PromSubqueryExpr { expr, .. }) => {
284 extract_tables_from_prom_expr(expr, names);
285 }
286 PromExpr::VectorSelector(selector) => {
287 extract_metric_name_from_vector_selector(selector, names);
288 }
289 PromExpr::MatrixSelector(PromMatrixSelector { vs, .. }) => {
290 extract_metric_name_from_vector_selector(vs, names);
291 }
292 PromExpr::Call(PromCall { args, .. }) => {
293 for arg in &args.args {
294 extract_tables_from_prom_expr(arg, names);
295 }
296 }
297 PromExpr::NumberLiteral(_) | PromExpr::StringLiteral(_) | PromExpr::Extension(_) => {}
298 }
299}
300
301fn extract_metric_name_from_vector_selector(
302 selector: &PromVectorSelector,
303 names: &mut HashSet<ObjectName>,
304) {
305 let metric_name = selector.name.clone().or_else(|| {
306 let mut metric_name_matchers = selector.matchers.find_matchers(METRIC_NAME);
307 if metric_name_matchers.len() == 1 && metric_name_matchers[0].op == MatchOp::Equal {
308 metric_name_matchers.pop().map(|matcher| matcher.value)
309 } else {
310 None
311 }
312 });
313 let Some(metric_name) = metric_name else {
314 return;
315 };
316
317 let schema_matcher = selector.matchers.matchers.iter().rev().find(|matcher| {
318 matcher.op == MatchOp::Equal
319 && (matcher.name == SCHEMA_MATCHER || matcher.name == DATABASE_MATCHER)
320 });
321
322 if let Some(schema) = schema_matcher {
323 names.insert(ObjectName(vec![
324 ObjectNamePart::Identifier(Ident::new(&schema.value)),
325 ObjectNamePart::Identifier(Ident::new(metric_name)),
326 ]));
327 } else {
328 names.insert(ObjectName(vec![ObjectNamePart::Identifier(Ident::new(
329 metric_name,
330 ))]));
331 }
332}
333
334pub fn location_to_index(sql: &str, location: &sqlparser::tokenizer::Location) -> usize {
336 let mut index = 0;
337 for (lno, line) in sql.lines().enumerate() {
338 if lno + 1 == location.line as usize {
339 index += location.column as usize;
340 break;
341 } else {
342 index += line.len() + 1; }
344 }
345 index - 1
348}
349
350fn extract_tables_from_sql_query(query: &sqlparser::ast::Query, names: &mut HashSet<ObjectName>) {
354 let mut cte_names = HashSet::new();
355 if let Some(with) = &query.with {
356 for cte in &with.cte_tables {
357 let cte_name = ParserContext::canonicalize_identifier(cte.alias.name.clone()).value;
358 let mut cte_query_names = HashSet::new();
359 extract_tables_from_sql_query(&cte.query, &mut cte_query_names);
360 if with.recursive {
361 cte_names.insert(cte_name.clone());
362 }
363 remove_cte_names(&mut cte_query_names, &cte_names);
364 names.extend(cte_query_names);
365 if !with.recursive {
366 cte_names.insert(cte_name);
367 }
368 }
369 }
370
371 let mut body_names = HashSet::new();
372 extract_tables_from_set_expr(&query.body, &mut body_names);
373 remove_cte_names(&mut body_names, &cte_names);
374 names.extend(body_names);
375}
376
377fn extract_tables_from_set_expr(set_expr: &SetExpr, names: &mut HashSet<ObjectName>) {
381 match set_expr {
382 SetExpr::Select(select) => {
383 for from in &select.from {
384 extract_tables_from_table_with_joins(from, names);
385 }
386 }
387 SetExpr::Query(query) => {
388 extract_tables_from_sql_query(query, names);
389 }
390 SetExpr::SetOperation { left, right, .. } => {
391 extract_tables_from_set_expr(left, names);
392 extract_tables_from_set_expr(right, names);
393 }
394 _ => {}
395 };
396}
397
398fn extract_tables_from_table_with_joins(
402 table_with_joins: &TableWithJoins,
403 names: &mut HashSet<ObjectName>,
404) {
405 table_factor_to_object_name(&table_with_joins.relation, names);
406 for join in &table_with_joins.joins {
407 table_factor_to_object_name(&join.relation, names);
408 }
409}
410
411fn table_factor_to_object_name(table_factor: &TableFactor, names: &mut HashSet<ObjectName>) {
415 match table_factor {
416 TableFactor::Table { name, .. } => {
417 names.insert(name.to_owned());
418 }
419 TableFactor::Derived { subquery, .. } => {
420 extract_tables_from_sql_query(subquery, names);
421 }
422 TableFactor::NestedJoin {
423 table_with_joins, ..
424 } => {
425 extract_tables_from_table_with_joins(table_with_joins, names);
426 }
427 TableFactor::Pivot { table, .. }
428 | TableFactor::Unpivot { table, .. }
429 | TableFactor::MatchRecognize { table, .. } => {
430 table_factor_to_object_name(table, names);
431 }
432 TableFactor::TableFunction { .. }
433 | TableFactor::Function { .. }
434 | TableFactor::UNNEST { .. }
435 | TableFactor::JsonTable { .. }
436 | TableFactor::OpenJsonTable { .. }
437 | TableFactor::XmlTable { .. }
438 | TableFactor::SemanticView { .. } => {}
439 }
440}
441
#[cfg(test)]
mod tests {
    use sqlparser::tokenizer::Token;

    use super::*;
    use crate::dialect::GreptimeDbDialect;
    use crate::parser::{ParseOptions, ParserContext};
    use crate::statements::statement::Statement;

    /// For every token of each sample statement, slicing the SQL text between
    /// the indices computed from the token's span must give back the token text.
    #[test]
    fn test_location_to_index() {
        // Samples: a single line, then multi-line text with a leading and/or
        // trailing newline.
        let testcases = vec![
            "SELECT * FROM t WHERE a = 1",
            r"
SELECT *
FROM
t
WHERE a =
1
",
            r"SELECT *
FROM
t
WHERE a =
1
",
            r"
SELECT *
FROM
t
WHERE a =
1",
        ];

        for sql in testcases {
            let mut parser = ParserContext::new(&GreptimeDbDialect {}, sql).unwrap();
            loop {
                let token = parser.parser.next_token();
                if token == Token::EOF {
                    break;
                }
                let span = token.span;
                let subslice =
                    &sql[location_to_index(sql, &span.start)..location_to_index(sql, &span.end)];
                assert_eq!(token.to_string(), subslice);
            }
        }
    }

    /// A flow defined by a TQL query reports the selected metric as its source
    /// table, whether the metric is given by name or by a `__name__` matcher.
    #[test]
    fn test_extract_tables_from_tql_query() {
        let testcases = vec![
            (
                r#"
CREATE FLOW calc_reqs SINK TO cnt_reqs AS
TQL EVAL (now() - '15s'::interval, now(), '5s') count_values("status_code", http_requests);"#,
                vec!["http_requests".to_string()],
            ),
            (
                r#"
CREATE FLOW calc_reqs SINK TO cnt_reqs AS
TQL EVAL (now() - '15s'::interval, now(), '5s') count_values("status_code", {__name__="http_requests"});"#,
                vec!["http_requests".to_string()],
            ),
        ];

        for (sql, expected_tables) in testcases {
            let mut stmts = ParserContext::create_with_dialect(
                sql,
                &GreptimeDbDialect {},
                ParseOptions::default(),
            )
            .unwrap();
            let Statement::CreateFlow(create_flow) = stmts.pop().unwrap() else {
                unreachable!()
            };

            let mut tables = extract_tables_from_query(&create_flow.query)
                .map(|table| format_raw_object_name(&table))
                .collect_vec();
            // The extraction order is unspecified, so sort before comparing.
            tables.sort();
            assert_eq!(expected_tables, tables);
        }
    }

    /// Tables referenced inside derived tables (subqueries in `FROM` and in a
    /// `JOIN`) are found.
    #[test]
    fn test_extract_tables_from_sql_query_with_derived_join() {
        let sql = r#"
CREATE FLOW flow_batch_join_subquery SINK TO flow_batch_join_sink
EVAL INTERVAL '1m' AS
SELECT a.symbol, b.mark_price
FROM (
 SELECT inst_id AS symbol, max(ts) AS mark_iv_ts
 FROM flow_batch_join_opt_summary
 GROUP BY inst_id
) a
LEFT JOIN (
 SELECT symbol, max(mark_price) AS mark_price
 FROM flow_batch_join_market_v5
 WHERE "type" = 'OPTION_MARK'
 GROUP BY symbol
) b ON a.symbol = b.symbol;
"#;
        let mut stmts =
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
                .unwrap();
        let Statement::CreateFlow(create_flow) = stmts.pop().unwrap() else {
            unreachable!()
        };

        let mut tables = extract_tables_from_query(&create_flow.query)
            .map(|table| format_raw_object_name(&table))
            .collect_vec();
        // The extraction order is unspecified, so sort before comparing.
        tables.sort();
        assert_eq!(
            vec![
                "flow_batch_join_market_v5".to_string(),
                "flow_batch_join_opt_summary".to_string(),
            ],
            tables
        );
    }

    /// CTE names are scoped correctly: a non-recursive CTE that selects from a
    /// table with its own name still reports that table, and a reference to an
    /// earlier CTE is not reported as a table.
    #[test]
    fn test_extract_tables_from_sql_query_with_cte_scopes() {
        let testcases = vec![
            (
                r#"
WITH source AS (
 SELECT * FROM source
)
SELECT * FROM source;
"#,
                vec!["source".to_string()],
            ),
            (
                r#"
WITH first_cte AS (
 SELECT * FROM physical_source
), second_cte AS (
 SELECT * FROM first_cte
)
SELECT * FROM second_cte;
"#,
                vec!["physical_source".to_string()],
            ),
        ];

        for (sql, expected_tables) in testcases {
            let mut stmts = ParserContext::create_with_dialect(
                sql,
                &GreptimeDbDialect {},
                ParseOptions::default(),
            )
            .unwrap();
            let Statement::Query(query) = stmts.pop().unwrap() else {
                unreachable!()
            };

            let mut tables = HashSet::new();
            extract_tables_from_sql_query(&query.inner, &mut tables);
            let mut tables = tables
                .into_iter()
                .map(|table| format_raw_object_name(&table))
                .collect_vec();
            // The extraction order is unspecified, so sort before comparing.
            tables.sort();
            assert_eq!(expected_tables, tables);
        }
    }

    /// A `__schema__` equality matcher on the selector qualifies the reported
    /// table name with that schema.
    #[test]
    fn test_extract_tables_from_tql_query_with_schema_matcher() {
        let sql = r#"
CREATE FLOW calc_reqs SINK TO cnt_reqs AS
TQL EVAL (now() - '15s'::interval, now(), '5s') count_values("status_code", http_requests{__schema__="greptime_private"});"#;
        let mut stmts =
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
                .unwrap();
        let Statement::CreateFlow(create_flow) = stmts.pop().unwrap() else {
            unreachable!()
        };

        let mut tables = extract_tables_from_query(&create_flow.query)
            .map(|table| format_raw_object_name(&table))
            .collect_vec();
        tables.sort();
        assert_eq!(vec!["greptime_private.http_requests".to_string()], tables);
    }
}