feat: sql dialect for different protocols (#1631)

* feat: add SqlDialect to query context

* feat: use session in postgrel handlers

* chore: refactor sql dialect

* feat: use different dialects for different sql protocols

* feat: adds GreptimeDbDialect

* refactor: replace GenericDialect with GreptimeDbDialect

* feat: save user info to session

* fix: compile error

* fix: test
This commit is contained in:
dennis zhuang
2023-05-30 09:52:35 +08:00
committed by GitHub
parent 563ce59071
commit ab5dfd31ec
31 changed files with 285 additions and 185 deletions

View File

@@ -322,7 +322,7 @@ pub(crate) fn to_alter_expr(
#[cfg(test)]
mod tests {
use session::context::QueryContext;
use sql::dialect::GenericDialect;
use sql::dialect::GreptimeDbDialect;
use sql::parser::ParserContext;
use sql::statements::statement::Statement;
@@ -331,7 +331,7 @@ mod tests {
#[test]
fn test_create_to_expr() {
let sql = "CREATE TABLE monitor (host STRING,ts TIMESTAMP,TIME INDEX (ts),PRIMARY KEY(host)) ENGINE=mito WITH(regions=1, ttl='3days', write_buffer_size='1024KB');";
let stmt = ParserContext::create_with_dialect(sql, &GenericDialect {})
let stmt = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {})
.unwrap()
.pop()
.unwrap();

View File

@@ -64,7 +64,7 @@ use servers::query_handler::{
};
use session::context::QueryContextRef;
use snafu::prelude::*;
use sql::dialect::GenericDialect;
use sql::dialect::Dialect;
use sql::parser::ParserContext;
use sql::statements::copy::CopyTable;
use sql::statements::statement::Statement;
@@ -447,8 +447,8 @@ impl FrontendInstance for Instance {
}
}
fn parse_stmt(sql: &str) -> Result<Vec<Statement>> {
ParserContext::create_with_dialect(sql, &GenericDialect {}).context(ParseSqlSnafu)
fn parse_stmt(sql: &str, dialect: &(dyn Dialect + Send + Sync)) -> Result<Vec<Statement>> {
ParserContext::create_with_dialect(sql, dialect).context(ParseSqlSnafu)
}
impl Instance {
@@ -473,7 +473,7 @@ impl SqlQueryHandler for Instance {
Err(e) => return vec![Err(e)],
};
match parse_stmt(query.as_ref())
match parse_stmt(query.as_ref(), query_ctx.sql_dialect())
.and_then(|stmts| query_interceptor.post_parsing(stmts, query_ctx.clone()))
{
Ok(stmts) => {
@@ -664,6 +664,7 @@ mod tests {
use datatypes::schema::{ColumnDefaultConstraint, ColumnSchema};
use query::query_engine::options::QueryOptions;
use session::context::QueryContext;
use sql::dialect::GreptimeDbDialect;
use strfmt::Format;
use super::*;
@@ -748,7 +749,7 @@ mod tests {
CREATE DATABASE test_database;
SHOW DATABASES;
"#;
let stmts = parse_stmt(sql).unwrap();
let stmts = parse_stmt(sql, &GreptimeDbDialect {}).unwrap();
assert_eq!(stmts.len(), 4);
for stmt in stmts {
let re = check_permission(plugins.clone(), &stmt, &query_ctx);
@@ -759,7 +760,7 @@ mod tests {
SHOW CREATE TABLE demo;
ALTER TABLE demo ADD COLUMN new_col INT;
"#;
let stmts = parse_stmt(sql).unwrap();
let stmts = parse_stmt(sql, &GreptimeDbDialect {}).unwrap();
assert_eq!(stmts.len(), 2);
for stmt in stmts {
let re = check_permission(plugins.clone(), &stmt, &query_ctx);
@@ -767,7 +768,7 @@ mod tests {
}
let sql = "USE randomschema";
let stmts = parse_stmt(sql).unwrap();
let stmts = parse_stmt(sql, &GreptimeDbDialect {}).unwrap();
let re = check_permission(plugins.clone(), &stmts[0], &query_ctx);
assert!(re.is_ok());
@@ -800,7 +801,7 @@ mod tests {
}
fn do_test(sql: &str, plugins: Arc<Plugins>, query_ctx: &QueryContextRef, is_ok: bool) {
let stmt = &parse_stmt(sql).unwrap()[0];
let stmt = &parse_stmt(sql, &GreptimeDbDialect {}).unwrap()[0];
let re = check_permission(plugins, stmt, query_ctx);
if is_ok {
assert!(re.is_ok());
@@ -828,12 +829,12 @@ mod tests {
// test show tables
let sql = "SHOW TABLES FROM public";
let stmt = parse_stmt(sql).unwrap();
let stmt = parse_stmt(sql, &GreptimeDbDialect {}).unwrap();
let re = check_permission(plugins.clone(), &stmt[0], &query_ctx);
assert!(re.is_ok());
let sql = "SHOW TABLES FROM wrongschema";
let stmt = parse_stmt(sql).unwrap();
let stmt = parse_stmt(sql, &GreptimeDbDialect {}).unwrap();
let re = check_permission(plugins.clone(), &stmt[0], &query_ctx);
assert!(re.is_err());

View File

@@ -874,7 +874,7 @@ fn find_partition_columns(
#[cfg(test)]
mod test {
use session::context::QueryContext;
use sql::dialect::GenericDialect;
use sql::dialect::GreptimeDbDialect;
use sql::parser::ParserContext;
use sql::statements::statement::Statement;
@@ -908,7 +908,7 @@ ENGINE=mito",
),
];
for (sql, expected) in cases {
let result = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap();
let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap();
match &result[0] {
Statement::CreateTable(c) => {
let expr = expr_factory::create_to_expr(c, QueryContext::arc()).unwrap();