1use std::str::FromStr;
16
17use chrono::{DateTime, Utc};
18use snafu::{OptionExt, ResultExt};
19use sqlparser::ast::{Ident, Query, Value};
20use sqlparser::dialect::Dialect;
21use sqlparser::keywords::Keyword;
22use sqlparser::parser::{Parser, ParserError, ParserOptions};
23use sqlparser::tokenizer::{Token, TokenWithSpan};
24
25use crate::ast::{Expr, ObjectName};
26use crate::error::{self, InvalidSqlSnafu, Result, SyntaxSnafu};
27use crate::parsers::tql_parser;
28use crate::statements::kill::Kill;
29use crate::statements::statement::Statement;
30use crate::statements::transform_statements;
31
32pub const FLOW: &str = "FLOW";
33
34#[derive(Clone, Debug, Default)]
36pub struct ParseOptions {
37 pub scheduled_time: Option<DateTime<Utc>>,
40}
41
42pub struct ParserContext<'a> {
44 pub(crate) parser: Parser<'a>,
45 pub(crate) sql: &'a str,
46 pub(crate) scheduled_time: Option<DateTime<Utc>>,
48}
49
50impl ParserContext<'_> {
51 pub fn new<'a>(dialect: &'a dyn Dialect, sql: &'a str) -> Result<ParserContext<'a>> {
53 let parser = Parser::new(dialect)
54 .with_options(ParserOptions::new().with_trailing_commas(true))
55 .try_with_sql(sql)
56 .context(SyntaxSnafu)?;
57
58 Ok(ParserContext {
59 parser,
60 sql,
61 scheduled_time: None,
62 })
63 }
64
65 pub fn parser_query(&mut self) -> Result<Box<Query>> {
67 self.parser.parse_query().context(SyntaxSnafu)
68 }
69
70 pub fn create_with_dialect(
72 sql: &str,
73 dialect: &dyn Dialect,
74 opts: ParseOptions,
75 ) -> Result<Vec<Statement>> {
76 let mut stmts: Vec<Statement> = Vec::new();
77
78 let mut parser_ctx = ParserContext::new(dialect, sql)?;
79 parser_ctx.scheduled_time = opts.scheduled_time;
80
81 let mut expecting_statement_delimiter = false;
82 loop {
83 while parser_ctx.parser.consume_token(&Token::SemiColon) {
85 expecting_statement_delimiter = false;
86 }
87
88 if parser_ctx.parser.peek_token() == Token::EOF {
89 break;
90 }
91 if expecting_statement_delimiter {
92 return parser_ctx.unsupported(parser_ctx.peek_token_as_string());
93 }
94
95 let statement = parser_ctx.parse_statement()?;
96 stmts.push(statement);
97 expecting_statement_delimiter = true;
98 }
99
100 transform_statements(&mut stmts)?;
101
102 Ok(stmts)
103 }
104
105 pub fn parse_table_name(sql: &str, dialect: &dyn Dialect) -> Result<ObjectName> {
106 let parser = Parser::new(dialect)
107 .with_options(ParserOptions::new().with_trailing_commas(true))
108 .try_with_sql(sql)
109 .context(SyntaxSnafu)?;
110 ParserContext {
111 parser,
112 sql,
113 scheduled_time: None,
114 }
115 .intern_parse_table_name()
116 }
117
118 pub(crate) fn intern_parse_table_name(&mut self) -> Result<ObjectName> {
119 let raw_table_name =
120 self.parser
121 .parse_object_name(false)
122 .context(error::UnexpectedSnafu {
123 expected: "a table name",
124 actual: self.parser.peek_token().to_string(),
125 })?;
126 Self::canonicalize_object_name(raw_table_name)
127 }
128
129 pub fn parse_function(sql: &str, dialect: &dyn Dialect) -> Result<Expr> {
130 let mut parser = Parser::new(dialect)
131 .with_options(ParserOptions::new().with_trailing_commas(true))
132 .try_with_sql(sql)
133 .context(SyntaxSnafu)?;
134
135 let function_name = parser.parse_identifier().context(SyntaxSnafu)?;
136 parser
137 .parse_function(vec![function_name].into())
138 .context(SyntaxSnafu)
139 }
140
141 pub fn parse_statement(&mut self) -> Result<Statement> {
143 match self.parser.peek_token().token {
144 Token::Word(w) => match w.keyword {
145 Keyword::CREATE => {
146 let _ = self.parser.next_token();
147 self.parse_create()
148 }
149
150 Keyword::EXPLAIN => {
151 let _ = self.parser.next_token();
152 self.parse_explain()
153 }
154
155 Keyword::SHOW => {
156 let _ = self.parser.next_token();
157 self.parse_show()
158 }
159
160 Keyword::DELETE => self.parse_delete(),
161
162 Keyword::DESCRIBE | Keyword::DESC => {
163 let _ = self.parser.next_token();
164 self.parse_describe()
165 }
166
167 Keyword::INSERT => self.parse_insert(),
168
169 Keyword::REPLACE => self.parse_replace(),
170
171 Keyword::SELECT | Keyword::VALUES => self.parse_query(),
172
173 Keyword::WITH => self.parse_with_tql(),
174
175 Keyword::ALTER => self.parse_alter(),
176
177 Keyword::DROP => self.parse_drop(),
178
179 Keyword::COPY => self.parse_copy(),
180
181 Keyword::TRUNCATE => self.parse_truncate(),
182
183 Keyword::COMMENT => self.parse_comment(),
184
185 Keyword::SET => self.parse_set_variables(),
186
187 Keyword::ADMIN => self.parse_admin_command(),
188
189 Keyword::NoKeyword
190 if w.quote_style.is_none() && w.value.to_uppercase() == tql_parser::TQL =>
191 {
192 self.parse_tql(false)
193 }
194
195 Keyword::DECLARE => self.parse_declare_cursor(),
196
197 Keyword::FETCH => self.parse_fetch_cursor(),
198
199 Keyword::CLOSE => self.parse_close_cursor(),
200
201 Keyword::USE => {
202 let _ = self.parser.next_token();
203
204 let database_name = self.parser.parse_identifier().with_context(|_| {
205 error::UnexpectedSnafu {
206 expected: "a database name",
207 actual: self.peek_token_as_string(),
208 }
209 })?;
210 Ok(Statement::Use(
211 Self::canonicalize_identifier(database_name).value,
212 ))
213 }
214
215 Keyword::KILL => {
216 let _ = self.parser.next_token();
217 let kill = if self.parser.parse_keyword(Keyword::QUERY) {
218 let connection_id_exp =
220 self.parser.parse_number_value().with_context(|_| {
221 error::UnexpectedSnafu {
222 expected: "MySQL numeric connection id",
223 actual: self.peek_token_as_string(),
224 }
225 })?;
226 let Value::Number(s, _) = connection_id_exp.value else {
227 return error::UnexpectedTokenSnafu {
228 expected: "MySQL numeric connection id",
229 actual: connection_id_exp.to_string(),
230 }
231 .fail();
232 };
233
234 let connection_id = u32::from_str(&s).map_err(|_| {
235 error::UnexpectedTokenSnafu {
236 expected: "MySQL numeric connection id",
237 actual: s,
238 }
239 .build()
240 })?;
241 Kill::ConnectionId(connection_id)
242 } else {
243 let process_id_ident =
244 self.parser.parse_literal_string().with_context(|_| {
245 error::UnexpectedSnafu {
246 expected: "process id string literal",
247 actual: self.peek_token_as_string(),
248 }
249 })?;
250 Kill::ProcessId(process_id_ident)
251 };
252
253 Ok(Statement::Kill(kill))
254 }
255
256 _ => self.unsupported(self.peek_token_as_string()),
257 },
258 Token::LParen => self.parse_query(),
259 unexpected => self.unsupported(unexpected.to_string()),
260 }
261 }
262
263 pub fn parse_mysql_prepare_stmt(sql: &str, dialect: &dyn Dialect) -> Result<(String, String)> {
265 ParserContext::new(dialect, sql)?.parse_mysql_prepare()
266 }
267
268 pub fn parse_mysql_execute_stmt(
270 sql: &str,
271 dialect: &dyn Dialect,
272 ) -> Result<(String, Vec<Expr>)> {
273 ParserContext::new(dialect, sql)?.parse_mysql_execute()
274 }
275
276 pub fn parse_mysql_deallocate_stmt(sql: &str, dialect: &dyn Dialect) -> Result<String> {
278 ParserContext::new(dialect, sql)?.parse_deallocate()
279 }
280
281 pub fn unsupported<T>(&self, keyword: String) -> Result<T> {
283 error::UnsupportedSnafu { keyword }.fail()
284 }
285
286 pub(crate) fn expected<T>(&self, expected: &str, found: TokenWithSpan) -> Result<T> {
288 Err(ParserError::ParserError(format!(
289 "Expected {expected}, found: {found}",
290 )))
291 .context(SyntaxSnafu)
292 }
293
294 pub fn matches_keyword(&mut self, expected: Keyword) -> bool {
295 match self.parser.peek_token().token {
296 Token::Word(w) => w.keyword == expected,
297 _ => false,
298 }
299 }
300
301 pub fn consume_token(&mut self, expected: &str) -> bool {
302 if self.peek_token_as_string().to_uppercase() == *expected.to_uppercase() {
303 let _ = self.parser.next_token();
304 true
305 } else {
306 false
307 }
308 }
309
310 #[inline]
311 pub(crate) fn peek_token_as_string(&self) -> String {
312 self.parser.peek_token().to_string()
313 }
314
315 pub fn canonicalize_identifier(ident: Ident) -> Ident {
317 if ident.quote_style.is_some() {
318 ident
319 } else {
320 Ident::new(ident.value.to_lowercase())
321 }
322 }
323
324 pub(crate) fn canonicalize_object_name(object_name: ObjectName) -> Result<ObjectName> {
326 object_name
327 .0
328 .into_iter()
329 .map(|x| {
330 x.as_ident()
331 .cloned()
332 .map(Self::canonicalize_identifier)
333 .with_context(|| InvalidSqlSnafu {
334 msg: format!("not an ident: '{x}'"),
335 })
336 })
337 .collect::<Result<Vec<_>>>()
338 .map(Into::into)
339 }
340
341 pub(crate) fn parse_object_name(&mut self) -> std::result::Result<ObjectName, ParserError> {
346 self.parser.parse_object_name(false)
347 }
348}
349
#[cfg(test)]
mod tests {
    use datatypes::prelude::ConcreteDataType;
    use sqlparser::dialect::MySqlDialect;

    use super::*;
    use crate::dialect::GreptimeDbDialect;
    use crate::statements::create::CreateTable;
    use crate::statements::kill::Kill;
    use crate::statements::sql_data_type_to_concrete_data_type;

    /// Parses `sql` with the GreptimeDB dialect and default parse options.
    fn parse_greptime(sql: &str) -> crate::error::Result<Vec<Statement>> {
        ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
    }

    /// Parses a `CREATE TABLE` and checks the concrete type of its first column.
    fn test_timestamp_precision(sql: &str, expected_type: ConcreteDataType) {
        let statement = parse_greptime(sql).unwrap().pop().unwrap();
        let Statement::CreateTable(CreateTable { columns, .. }) = statement else {
            unreachable!()
        };
        let ts_col = columns.first().unwrap();
        let actual = sql_data_type_to_concrete_data_type(ts_col.data_type()).unwrap();
        assert_eq!(expected_type, actual);
    }

    /// Asserts that `sql` is exactly one `KILL QUERY` for connection `expected`.
    fn assert_kill_connection_id(sql: &str, expected: u32) {
        let statements = parse_greptime(sql).unwrap();
        assert_eq!(statements.len(), 1);
        let Statement::Kill(Kill::ConnectionId(connection_id)) = &statements[0] else {
            panic!("Expected Kill::ConnectionId statement");
        };
        assert_eq!(*connection_id, expected);
    }

    /// Asserts that `sql` is exactly one `KILL '<id>'` for process `expected`.
    fn assert_kill_process_id(sql: &str, expected: &str) {
        let statements = parse_greptime(sql).unwrap();
        assert_eq!(statements.len(), 1);
        let Statement::Kill(Kill::ProcessId(process_id)) = &statements[0] else {
            panic!("Expected Kill::ProcessId statement");
        };
        assert_eq!(process_id, expected);
    }

    #[test]
    pub fn test_create_table_with_precision() {
        let cases = [
            (
                "create table demo (ts timestamp time index, cnt int);",
                ConcreteDataType::timestamp_millisecond_datatype(),
            ),
            (
                "create table demo (ts timestamp(0) time index, cnt int);",
                ConcreteDataType::timestamp_second_datatype(),
            ),
            (
                "create table demo (ts timestamp(3) time index, cnt int);",
                ConcreteDataType::timestamp_millisecond_datatype(),
            ),
            (
                "create table demo (ts timestamp(6) time index, cnt int);",
                ConcreteDataType::timestamp_microsecond_datatype(),
            ),
            (
                "create table demo (ts timestamp(9) time index, cnt int);",
                ConcreteDataType::timestamp_nanosecond_datatype(),
            ),
        ];
        for (sql, expected_type) in cases {
            test_timestamp_precision(sql, expected_type);
        }
    }

    /// `timestamp(1)` is not one of the accepted precisions, so this must panic.
    #[test]
    #[should_panic]
    pub fn test_create_table_with_invalid_precision() {
        test_timestamp_precision(
            "create table demo (ts timestamp(1) time index, cnt int);",
            ConcreteDataType::timestamp_millisecond_datatype(),
        );
    }

    #[test]
    pub fn test_parse_table_name() {
        // (input, expected number of name parts, expected canonical rendering)
        let cases = [
            ("a.b.c", 3, "a.b.c"),
            ("a.b", 2, "a.b"),
            // Unquoted parts are lowercased, quoted parts are preserved.
            ("Test.\"public-test\"", 2, "test.\"public-test\""),
            ("HelloWorld", 1, "helloworld"),
        ];
        for (table_name, parts, rendered) in cases {
            let object_name =
                ParserContext::parse_table_name(table_name, &GreptimeDbDialect {}).unwrap();
            assert_eq!(object_name.0.len(), parts);
            assert_eq!(object_name.to_string(), rendered);
        }
    }

    #[test]
    pub fn test_parse_mysql_prepare_stmt() {
        let cases = [
            ("PREPARE stmt1 FROM 'SELECT * FROM t1 WHERE id = ?';", "stmt1"),
            ("PREPARE stmt2 FROM \"SELECT * FROM t1 WHERE id = ?\"", "stmt2"),
        ];
        for (sql, expected_name) in cases {
            let (stmt_name, stmt) =
                ParserContext::parse_mysql_prepare_stmt(sql, &MySqlDialect {}).unwrap();
            assert_eq!(stmt_name, expected_name);
            assert_eq!(stmt, "SELECT * FROM t1 WHERE id = ?");
        }
    }

    #[test]
    pub fn test_parse_mysql_execute_stmt() {
        let cases: [(&str, &str, &[&str]); 3] = [
            (
                "EXECUTE stmt1 USING 1, 'hello';",
                "stmt1",
                &["1", "'hello'"],
            ),
            ("EXECUTE stmt2;", "stmt2", &[]),
            // A trailing comma after the last parameter is tolerated.
            (
                "EXECUTE stmt3 USING 231, 'hello', \"2003-03-1\", NULL, ;",
                "stmt3",
                &["231", "'hello'", "\"2003-03-1\"", "NULL"],
            ),
        ];
        for (sql, expected_name, expected_params) in cases {
            let (stmt_name, params) =
                ParserContext::parse_mysql_execute_stmt(sql, &GreptimeDbDialect {}).unwrap();
            assert_eq!(stmt_name, expected_name);
            let rendered: Vec<String> = params.iter().map(|p| p.to_string()).collect();
            assert_eq!(rendered, expected_params);
        }
    }

    #[test]
    pub fn test_parse_mysql_deallocate_stmt() {
        for (sql, expected) in [("DEALLOCATE stmt1;", "stmt1"), ("DEALLOCATE stmt2", "stmt2")] {
            let stmt_name =
                ParserContext::parse_mysql_deallocate_stmt(sql, &MySqlDialect {}).unwrap();
            assert_eq!(stmt_name, expected);
        }
    }

    #[test]
    pub fn test_parse_kill_query_statement() {
        assert_kill_connection_id("KILL QUERY 123", 123);
        assert_kill_connection_id("KILL QUERY 999999", 999999);
    }

    #[test]
    pub fn test_parse_kill_process_statement() {
        assert_kill_process_id("KILL 'process-123'", "process-123");
        assert_kill_process_id("KILL \"process-456\"", "process-456");
        assert_kill_process_id(
            "KILL 'f47ac10b-58cc-4372-a567-0e02b2c3d479'",
            "f47ac10b-58cc-4372-a567-0e02b2c3d479",
        );
    }

    #[test]
    pub fn test_parse_kill_statement_errors() {
        let rejected = [
            // Missing connection id.
            "KILL QUERY",
            // Connection id must be numeric.
            "KILL QUERY 'not-a-number'",
            // Missing process id.
            "KILL",
            // u32::MAX + 1 does not fit in a connection id.
            "KILL QUERY 4294967296",
        ];
        for sql in rejected {
            assert!(parse_greptime(sql).is_err());
        }
    }

    #[test]
    pub fn test_parse_kill_statement_edge_cases() {
        assert_kill_connection_id("KILL QUERY 0", 0);
        // u32::MAX is the largest accepted connection id.
        assert_kill_connection_id("KILL QUERY 4294967295", 4294967295);
        assert_kill_process_id("KILL ''", "");
    }

    #[test]
    pub fn test_parse_kill_statement_case_insensitive() {
        assert_kill_connection_id("kill query 123", 123);
        assert_kill_connection_id("Kill Query 456", 456);
    }
}