feat: sql dialect for different protocols (#1631)

* feat: add SqlDialect to query context

* feat: use session in postgrel handlers

* chore: refactor sql dialect

* feat: use different dialects for different sql protocols

* feat: adds GreptimeDbDialect

* refactor: replace GenericDialect with GreptimeDbDialect

* feat: save user info to session

* fix: compile error

* fix: test
This commit is contained in:
dennis zhuang
2023-05-30 09:52:35 +08:00
committed by GitHub
parent 563ce59071
commit ab5dfd31ec
31 changed files with 285 additions and 185 deletions

2
Cargo.lock generated
View File

@@ -8216,6 +8216,7 @@ dependencies = [
"common-catalog",
"common-telemetry",
"common-time",
"sql",
]
[[package]]
@@ -8487,7 +8488,6 @@ name = "sql"
version = "0.2.0"
dependencies = [
"api",
"catalog",
"common-base",
"common-catalog",
"common-datasource",

View File

@@ -100,7 +100,7 @@ mod tests {
use query::parser::{QueryLanguageParser, QueryStatement};
use query::query_engine::SqlStatementExecutor;
use session::context::QueryContext;
use sql::dialect::GenericDialect;
use sql::dialect::GreptimeDbDialect;
use sql::parser::ParserContext;
use sql::statements::statement::Statement;
@@ -108,7 +108,7 @@ mod tests {
use crate::tests::test_util::MockInstance;
fn parse_sql(sql: &str) -> AlterTable {
let mut stmt = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap();
let mut stmt = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap();
assert_eq!(1, stmt.len());
let stmt = stmt.remove(0);
assert_matches!(stmt, Statement::Alter(_));

View File

@@ -253,7 +253,7 @@ mod tests {
use query::parser::{QueryLanguageParser, QueryStatement};
use query::query_engine::SqlStatementExecutor;
use session::context::QueryContext;
use sql::dialect::GenericDialect;
use sql::dialect::GreptimeDbDialect;
use sql::parser::ParserContext;
use sql::statements::statement::Statement;
@@ -262,7 +262,7 @@ mod tests {
use crate::tests::test_util::MockInstance;
fn sql_to_statement(sql: &str) -> CreateTable {
let mut res = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap();
let mut res = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap();
assert_eq!(1, res.len());
match res.pop().unwrap() {
Statement::CreateTable(c) => c,

View File

@@ -322,7 +322,7 @@ pub(crate) fn to_alter_expr(
#[cfg(test)]
mod tests {
use session::context::QueryContext;
use sql::dialect::GenericDialect;
use sql::dialect::GreptimeDbDialect;
use sql::parser::ParserContext;
use sql::statements::statement::Statement;
@@ -331,7 +331,7 @@ mod tests {
#[test]
fn test_create_to_expr() {
let sql = "CREATE TABLE monitor (host STRING,ts TIMESTAMP,TIME INDEX (ts),PRIMARY KEY(host)) ENGINE=mito WITH(regions=1, ttl='3days', write_buffer_size='1024KB');";
let stmt = ParserContext::create_with_dialect(sql, &GenericDialect {})
let stmt = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {})
.unwrap()
.pop()
.unwrap();

View File

@@ -64,7 +64,7 @@ use servers::query_handler::{
};
use session::context::QueryContextRef;
use snafu::prelude::*;
use sql::dialect::GenericDialect;
use sql::dialect::Dialect;
use sql::parser::ParserContext;
use sql::statements::copy::CopyTable;
use sql::statements::statement::Statement;
@@ -447,8 +447,8 @@ impl FrontendInstance for Instance {
}
}
fn parse_stmt(sql: &str) -> Result<Vec<Statement>> {
ParserContext::create_with_dialect(sql, &GenericDialect {}).context(ParseSqlSnafu)
fn parse_stmt(sql: &str, dialect: &(dyn Dialect + Send + Sync)) -> Result<Vec<Statement>> {
ParserContext::create_with_dialect(sql, dialect).context(ParseSqlSnafu)
}
impl Instance {
@@ -473,7 +473,7 @@ impl SqlQueryHandler for Instance {
Err(e) => return vec![Err(e)],
};
match parse_stmt(query.as_ref())
match parse_stmt(query.as_ref(), query_ctx.sql_dialect())
.and_then(|stmts| query_interceptor.post_parsing(stmts, query_ctx.clone()))
{
Ok(stmts) => {
@@ -664,6 +664,7 @@ mod tests {
use datatypes::schema::{ColumnDefaultConstraint, ColumnSchema};
use query::query_engine::options::QueryOptions;
use session::context::QueryContext;
use sql::dialect::GreptimeDbDialect;
use strfmt::Format;
use super::*;
@@ -748,7 +749,7 @@ mod tests {
CREATE DATABASE test_database;
SHOW DATABASES;
"#;
let stmts = parse_stmt(sql).unwrap();
let stmts = parse_stmt(sql, &GreptimeDbDialect {}).unwrap();
assert_eq!(stmts.len(), 4);
for stmt in stmts {
let re = check_permission(plugins.clone(), &stmt, &query_ctx);
@@ -759,7 +760,7 @@ mod tests {
SHOW CREATE TABLE demo;
ALTER TABLE demo ADD COLUMN new_col INT;
"#;
let stmts = parse_stmt(sql).unwrap();
let stmts = parse_stmt(sql, &GreptimeDbDialect {}).unwrap();
assert_eq!(stmts.len(), 2);
for stmt in stmts {
let re = check_permission(plugins.clone(), &stmt, &query_ctx);
@@ -767,7 +768,7 @@ mod tests {
}
let sql = "USE randomschema";
let stmts = parse_stmt(sql).unwrap();
let stmts = parse_stmt(sql, &GreptimeDbDialect {}).unwrap();
let re = check_permission(plugins.clone(), &stmts[0], &query_ctx);
assert!(re.is_ok());
@@ -800,7 +801,7 @@ mod tests {
}
fn do_test(sql: &str, plugins: Arc<Plugins>, query_ctx: &QueryContextRef, is_ok: bool) {
let stmt = &parse_stmt(sql).unwrap()[0];
let stmt = &parse_stmt(sql, &GreptimeDbDialect {}).unwrap()[0];
let re = check_permission(plugins, stmt, query_ctx);
if is_ok {
assert!(re.is_ok());
@@ -828,12 +829,12 @@ mod tests {
// test show tables
let sql = "SHOW TABLES FROM public";
let stmt = parse_stmt(sql).unwrap();
let stmt = parse_stmt(sql, &GreptimeDbDialect {}).unwrap();
let re = check_permission(plugins.clone(), &stmt[0], &query_ctx);
assert!(re.is_ok());
let sql = "SHOW TABLES FROM wrongschema";
let stmt = parse_stmt(sql).unwrap();
let stmt = parse_stmt(sql, &GreptimeDbDialect {}).unwrap();
let re = check_permission(plugins.clone(), &stmt[0], &query_ctx);
assert!(re.is_err());

View File

@@ -874,7 +874,7 @@ fn find_partition_columns(
#[cfg(test)]
mod test {
use session::context::QueryContext;
use sql::dialect::GenericDialect;
use sql::dialect::GreptimeDbDialect;
use sql::parser::ParserContext;
use sql::statements::statement::Statement;
@@ -908,7 +908,7 @@ ENGINE=mito",
),
];
for (sql, expected) in cases {
let result = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap();
let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap();
match &result[0] {
Statement::CreateTable(c) => {
let expr = expr_factory::create_to_expr(c, QueryContext::arc()).unwrap();

View File

@@ -26,7 +26,7 @@ use promql_parser::parser::ast::{Extension as NodeExtension, ExtensionExpr};
use promql_parser::parser::Expr::Extension;
use promql_parser::parser::{EvalStmt, Expr, ValueType};
use snafu::ResultExt;
use sql::dialect::GenericDialect;
use sql::dialect::GreptimeDbDialect;
use sql::parser::ParserContext;
use sql::statements::statement::Statement;
@@ -108,7 +108,7 @@ pub struct QueryLanguageParser {}
impl QueryLanguageParser {
pub fn parse_sql(sql: &str) -> Result<QueryStatement> {
let _timer = timer!(METRIC_PARSE_SQL_ELAPSED);
let mut statement = ParserContext::create_with_dialect(sql, &GenericDialect {})
let mut statement = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {})
.map_err(BoxedError::new)
.context(QueryParseSnafu {
query: sql.to_string(),

View File

@@ -20,7 +20,7 @@ use sql::ast::{
ColumnDef, ColumnOption, ColumnOptionDef, Expr, ObjectName, SqlOption, TableConstraint,
Value as SqlValue,
};
use sql::dialect::GenericDialect;
use sql::dialect::GreptimeDbDialect;
use sql::parser::ParserContext;
use sql::statements::create::{CreateTable, TIME_INDEX};
use sql::statements::{self};
@@ -108,7 +108,7 @@ fn create_column_def(column_schema: &ColumnSchema) -> Result<ColumnDef> {
.with_context(|_| ConvertSqlValueSnafu { value: v.clone() })?,
),
ColumnDefaultConstraint::Function(expr) => {
ParserContext::parse_function(expr, &GenericDialect {}).context(SqlSnafu)?
ParserContext::parse_function(expr, &GreptimeDbDialect {}).context(SqlSnafu)?
}
};

View File

@@ -32,9 +32,9 @@ use opensrv_mysql::{
use parking_lot::RwLock;
use rand::RngCore;
use session::context::Channel;
use session::Session;
use session::{Session, SessionRef};
use snafu::ensure;
use sql::dialect::GenericDialect;
use sql::dialect::MySqlDialect;
use sql::parser::ParserContext;
use sql::statements::statement::Statement;
use tokio::io::AsyncWrite;
@@ -48,7 +48,7 @@ use crate::query_handler::sql::ServerSqlQueryHandlerRef;
pub struct MysqlInstanceShim {
query_handler: ServerSqlQueryHandlerRef,
salt: [u8; 20],
session: Arc<Session>,
session: SessionRef,
user_provider: Option<UserProviderRef>,
// TODO(SSebo): use something like moka to achieve TTL or LRU
prepared_stmts: Arc<RwLock<HashMap<u32, String>>>,
@@ -77,7 +77,7 @@ impl MysqlInstanceShim {
MysqlInstanceShim {
query_handler,
salt: scramble,
session: Arc::new(Session::new(client_addr, Channel::Mysql)),
session: Arc::new(Session::new(Some(client_addr), Channel::Mysql)),
user_provider,
prepared_stmts: Default::default(),
prepared_stmts_counter: AtomicU32::new(1),
@@ -140,9 +140,13 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
let username = String::from_utf8_lossy(username);
let mut user_info = None;
let addr = self.session.conn_info().client_host.to_string();
let addr = self
.session
.conn_info()
.client_addr
.map(|addr| addr.to_string());
if let Some(user_provider) = &self.user_provider {
let user_id = Identity::UserId(&username, Some(addr.as_str()));
let user_id = Identity::UserId(&username, addr.as_deref());
let password = match auth_plugin {
"mysql_native_password" => Password::MysqlNativePassword(auth_data, salt),
@@ -331,7 +335,7 @@ fn format_duration(duration: Duration) -> String {
}
async fn validate_query(query: &str) -> Result<Statement> {
let statement = ParserContext::create_with_dialect(query, &GenericDialect {});
let statement = ParserContext::create_with_dialect(query, &MySqlDialect {});
let mut statement = statement.map_err(|e| {
InvalidPrepareStatementSnafu {
err_msg: e.to_string(),

View File

@@ -31,7 +31,8 @@ use pgwire::api::auth::ServerParameterProvider;
use pgwire::api::store::MemPortalStore;
use pgwire::api::{ClientInfo, MakeHandler};
pub use server::PostgresServer;
use session::context::{QueryContext, QueryContextRef};
use session::context::Channel;
use session::Session;
use sql::statements::statement::Statement;
use self::auth_handler::PgLoginVerifier;
@@ -73,7 +74,7 @@ pub struct PostgresServerHandler {
force_tls: bool,
param_provider: Arc<GreptimeDBStartupParameters>,
query_ctx: QueryContextRef,
session: Session,
portal_store: Arc<MemPortalStore<(Statement, String)>>,
query_parser: Arc<POCQueryParser>,
}
@@ -90,18 +91,18 @@ pub(crate) struct MakePostgresServerHandler {
}
impl MakeHandler for MakePostgresServerHandler {
type Handler = Arc<PostgresServerHandler>;
type Handler = PostgresServerHandler;
fn make(&self) -> Self::Handler {
Arc::new(PostgresServerHandler {
PostgresServerHandler {
query_handler: self.query_handler.clone(),
login_verifier: PgLoginVerifier::new(self.user_provider.clone()),
force_tls: self.force_tls,
param_provider: self.param_provider.clone(),
query_ctx: QueryContext::arc(),
session: Session::new(None, Channel::Postgres),
portal_store: Arc::new(MemPortalStore::new()),
query_parser: self.query_parser.clone(),
})
}
}
}

View File

@@ -24,7 +24,8 @@ use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
use pgwire::messages::response::ErrorResponse;
use pgwire::messages::startup::Authentication;
use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage};
use session::context::QueryContextRef;
use session::context::UserInfo;
use session::Session;
use super::PostgresServerHandler;
use crate::auth::{Identity, Password, UserProviderRef};
@@ -112,15 +113,19 @@ impl PgLoginVerifier {
}
}
fn set_query_context_from_client_info<C>(client: &C, query_context: QueryContextRef)
fn set_client_info<C>(client: &C, session: &Session)
where
C: ClientInfo,
{
let ctx = session.context();
if let Some(current_catalog) = client.metadata().get(super::METADATA_CATALOG) {
query_context.set_current_catalog(current_catalog);
ctx.set_current_catalog(current_catalog);
}
if let Some(current_schema) = client.metadata().get(super::METADATA_SCHEMA) {
query_context.set_current_schema(current_schema);
ctx.set_current_schema(current_schema);
}
if let Some(username) = client.metadata().get(super::METADATA_USER) {
session.set_user_info(UserInfo::new(username));
}
}
@@ -170,7 +175,7 @@ impl StartupHandler for PostgresServerHandler {
))
.await?;
} else {
set_query_context_from_client_info(client, self.query_ctx.clone());
set_client_info(client, &self.session);
auth::finish_authentication(client, self.param_provider.as_ref()).await;
}
}
@@ -193,7 +198,7 @@ impl StartupHandler for PostgresServerHandler {
)
.await;
}
set_query_context_from_client_info(client, self.query_ctx.clone());
set_client_info(client, &self.session);
auth::finish_authentication(client, self.param_provider.as_ref()).await;
}
_ => {}

View File

@@ -33,7 +33,7 @@ use pgwire::api::stmt::QueryParser;
use pgwire::api::store::MemPortalStore;
use pgwire::api::{ClientInfo, Type};
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
use sql::dialect::GenericDialect;
use sql::dialect::PostgreSqlDialect;
use sql::parser::ParserContext;
use sql::statements::statement::Statement;
@@ -55,13 +55,13 @@ impl SimpleQueryHandler for PostgresServerHandler {
),
(
crate::metrics::METRIC_DB_LABEL,
self.query_ctx.get_db_string()
self.session.context().get_db_string()
)
]
);
let outputs = self
.query_handler
.do_query(query, self.query_ctx.clone())
.do_query(query, self.session.context())
.await;
let mut results = Vec::with_capacity(outputs.len());
@@ -260,7 +260,7 @@ impl QueryParser for POCQueryParser {
fn parse_sql(&self, sql: &str, types: &[Type]) -> PgWireResult<Self::Statement> {
increment_counter!(crate::metrics::METRIC_POSTGRES_PREPARED_COUNT);
let mut stmts = ParserContext::create_with_dialect(sql, &GenericDialect {})
let mut stmts = ParserContext::create_with_dialect(sql, &PostgreSqlDialect {})
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
if stmts.len() != 1 {
Err(PgWireError::UserError(Box::new(ErrorInfo::new(
@@ -361,7 +361,7 @@ impl ExtendedQueryHandler for PostgresServerHandler {
),
(
crate::metrics::METRIC_DB_LABEL,
self.query_ctx.get_db_string()
self.session.context().get_db_string()
)
]
);
@@ -376,7 +376,7 @@ impl ExtendedQueryHandler for PostgresServerHandler {
let output = self
.query_handler
.do_query(&sql, self.query_ctx.clone())
.do_query(&sql, self.session.context())
.await
.remove(0);
@@ -407,7 +407,7 @@ impl ExtendedQueryHandler for PostgresServerHandler {
if let Some(schema) = self
.query_handler
.do_describe(stmt.clone(), self.query_ctx.clone())
.do_describe(stmt.clone(), self.session.context())
.await
.map_err(|e| PgWireError::ApiError(Box::new(e)))?
{

View File

@@ -73,19 +73,22 @@ impl PostgresServer {
accepting_stream.for_each(move |tcp_stream| {
let io_runtime = io_runtime.clone();
let tls_acceptor = tls_acceptor.clone();
let handler = handler.make();
let mut handler = handler.make();
async move {
match tcp_stream {
Err(error) => error!("Broken pipe: {}", error), // IoError doesn't impl ErrorExt.
Ok(io_stream) => {
match io_stream.peer_addr() {
Ok(addr) => debug!("PostgreSQL client coming from {}", addr),
Ok(addr) => {
handler.session.mut_conn_info().client_addr = Some(addr);
debug!("PostgreSQL client coming from {}", addr)
}
Err(e) => warn!("Failed to get PostgreSQL client addr, err: {}", e),
}
io_runtime.spawn(async move {
increment_gauge!(crate::metrics::METRIC_POSTGRES_CONNECTIONS, 1.0);
let handler = Arc::new(handler);
let r = process_socket(
io_stream,
tls_acceptor.clone(),

View File

@@ -9,3 +9,4 @@ arc-swap = "1.5"
common-catalog = { path = "../common/catalog" }
common-telemetry = { path = "../common/telemetry" }
common-time = { path = "../common/time" }
sql = { path = "../sql" }

View File

@@ -21,6 +21,7 @@ use common_catalog::build_db_string;
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_telemetry::debug;
use common_time::TimeZone;
use sql::dialect::{Dialect, GreptimeDbDialect, MySqlDialect, PostgreSqlDialect};
pub type QueryContextRef = Arc<QueryContext>;
pub type ConnInfoRef = Arc<ConnInfo>;
@@ -30,6 +31,7 @@ pub struct QueryContext {
current_catalog: ArcSwap<String>,
current_schema: ArcSwap<String>,
time_zone: ArcSwap<Option<TimeZone>>,
sql_dialect: Box<dyn Dialect + Send + Sync>,
}
impl Default for QueryContext {
@@ -59,25 +61,42 @@ impl QueryContext {
current_catalog: ArcSwap::new(Arc::new(DEFAULT_CATALOG_NAME.to_string())),
current_schema: ArcSwap::new(Arc::new(DEFAULT_SCHEMA_NAME.to_string())),
time_zone: ArcSwap::new(Arc::new(None)),
sql_dialect: Box::new(GreptimeDbDialect {}),
}
}
pub fn with(catalog: &str, schema: &str) -> Self {
Self::with_sql_dialect(catalog, schema, Box::new(GreptimeDbDialect {}))
}
pub fn with_sql_dialect(
catalog: &str,
schema: &str,
sql_dialect: Box<dyn Dialect + Send + Sync>,
) -> Self {
Self {
current_catalog: ArcSwap::new(Arc::new(catalog.to_string())),
current_schema: ArcSwap::new(Arc::new(schema.to_string())),
time_zone: ArcSwap::new(Arc::new(None)),
sql_dialect,
}
}
#[inline]
pub fn current_schema(&self) -> String {
self.current_schema.load().as_ref().clone()
}
#[inline]
pub fn current_catalog(&self) -> String {
self.current_catalog.load().as_ref().clone()
}
#[inline]
pub fn sql_dialect(&self) -> &(dyn Dialect + Send + Sync) {
&*self.sql_dialect
}
pub fn set_current_schema(&self, schema: &str) {
let last = self.current_schema.swap(Arc::new(schema.to_string()));
if schema != last.as_str() {
@@ -142,15 +161,30 @@ impl UserInfo {
}
}
#[derive(Debug)]
pub struct ConnInfo {
pub client_host: SocketAddr,
pub client_addr: Option<SocketAddr>,
pub channel: Channel,
}
impl std::fmt::Display for ConnInfo {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(
f,
"{}[{}]",
self.channel,
self.client_addr
.map(|addr| addr.to_string())
.as_deref()
.unwrap_or("unknown client addr")
)
}
}
impl ConnInfo {
pub fn new(client_host: SocketAddr, channel: Channel) -> Self {
pub fn new(client_addr: Option<SocketAddr>, channel: Channel) -> Self {
Self {
client_host,
client_addr,
channel,
}
}
@@ -158,13 +192,26 @@ impl ConnInfo {
#[derive(Debug, PartialEq)]
pub enum Channel {
Grpc,
Http,
Mysql,
Postgres,
Opentsdb,
Influxdb,
Prometheus,
}
impl Channel {
pub fn dialect(&self) -> Box<dyn Dialect + Send + Sync> {
match self {
Channel::Mysql => Box::new(MySqlDialect {}),
Channel::Postgres => Box::new(PostgreSqlDialect {}),
}
}
}
impl std::fmt::Display for Channel {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
Channel::Mysql => write!(f, "mysql"),
Channel::Postgres => write!(f, "postgres"),
}
}
}
#[cfg(test)]
@@ -175,7 +222,7 @@ mod test {
#[test]
fn test_session() {
let session = Session::new("127.0.0.1:9000".parse().unwrap(), Channel::Mysql);
let session = Session::new(Some("127.0.0.1:9000".parse().unwrap()), Channel::Mysql);
// test user_info
assert_eq!(session.user_info().username(), "greptime");
session.set_user_info(UserInfo::new("root"));
@@ -183,11 +230,11 @@ mod test {
// test channel
assert_eq!(session.conn_info().channel, Channel::Mysql);
assert_eq!(
session.conn_info().client_host.ip().to_string(),
"127.0.0.1"
);
assert_eq!(session.conn_info().client_host.port(), 9000);
let client_addr = session.conn_info().client_addr.as_ref().unwrap();
assert_eq!(client_addr.ip().to_string(), "127.0.0.1");
assert_eq!(client_addr.port(), 9000);
assert_eq!("mysql[127.0.0.1:9000]", session.conn_info().to_string());
}
#[test]

View File

@@ -18,33 +18,54 @@ use std::net::SocketAddr;
use std::sync::Arc;
use arc_swap::ArcSwap;
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use crate::context::{Channel, ConnInfo, ConnInfoRef, QueryContext, QueryContextRef, UserInfo};
use crate::context::{Channel, ConnInfo, QueryContext, QueryContextRef, UserInfo};
/// Session for persistent connection such as MySQL, PostgreSQL etc.
#[derive(Debug)]
pub struct Session {
query_ctx: QueryContextRef,
user_info: ArcSwap<UserInfo>,
conn_info: ConnInfoRef,
conn_info: ConnInfo,
}
pub type SessionRef = Arc<Session>;
impl Session {
pub fn new(addr: SocketAddr, channel: Channel) -> Self {
pub fn new(addr: Option<SocketAddr>, channel: Channel) -> Self {
Session {
query_ctx: Arc::new(QueryContext::new()),
query_ctx: Arc::new(QueryContext::with_sql_dialect(
DEFAULT_CATALOG_NAME,
DEFAULT_SCHEMA_NAME,
channel.dialect(),
)),
user_info: ArcSwap::new(Arc::new(UserInfo::default())),
conn_info: Arc::new(ConnInfo::new(addr, channel)),
conn_info: ConnInfo::new(addr, channel),
}
}
#[inline]
pub fn context(&self) -> QueryContextRef {
self.query_ctx.clone()
}
pub fn conn_info(&self) -> ConnInfoRef {
self.conn_info.clone()
#[inline]
pub fn conn_info(&self) -> &ConnInfo {
&self.conn_info
}
#[inline]
pub fn mut_conn_info(&mut self) -> &mut ConnInfo {
&mut self.conn_info
}
#[inline]
pub fn user_info(&self) -> Arc<UserInfo> {
self.user_info.load().clone()
}
#[inline]
pub fn set_user_info(&self, user_info: UserInfo) {
self.user_info.store(Arc::new(user_info));
}

View File

@@ -6,7 +6,6 @@ license.workspace = true
[dependencies]
api = { path = "../api" }
catalog = { path = "../catalog" }
common-base = { path = "../common/base" }
common-catalog = { path = "../common/catalog" }
common-datasource = { path = "../common/datasource" }

View File

@@ -12,6 +12,32 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// todo(hl) wrap sqlparser dialects
pub use sqlparser::dialect::{Dialect, MySqlDialect, PostgreSqlDialect};
pub use sqlparser::dialect::{Dialect, GenericDialect};
/// GreptimeDb dialect
#[derive(Debug, Clone)]
pub struct GreptimeDbDialect {}
impl Dialect for GreptimeDbDialect {
fn is_identifier_start(&self, ch: char) -> bool {
ch.is_alphabetic() || ch == '_' || ch == '#' || ch == '@'
}
fn is_identifier_part(&self, ch: char) -> bool {
ch.is_alphabetic()
|| ch.is_ascii_digit()
|| ch == '@'
|| ch == '$'
|| ch == '#'
|| ch == '_'
}
// Accepts both `identifier` and "identifier".
fn is_delimited_identifier_start(&self, ch: char) -> bool {
ch == '`' || ch == '"'
}
fn supports_filter_during_aggregation(&self) -> bool {
true
}
}

View File

@@ -393,16 +393,16 @@ mod tests {
use sqlparser::ast::{
Ident, ObjectName, Query as SpQuery, Statement as SpStatement, WildcardAdditionalOptions,
};
use sqlparser::dialect::GenericDialect;
use super::*;
use crate::dialect::GreptimeDbDialect;
use crate::statements::create::CreateTable;
use crate::statements::sql_data_type_to_concrete_data_type;
#[test]
pub fn test_show_database_all() {
let sql = "SHOW DATABASES";
let result = ParserContext::create_with_dialect(sql, &GenericDialect {});
let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {});
let stmts = result.unwrap();
assert_eq!(1, stmts.len());
@@ -417,7 +417,7 @@ mod tests {
#[test]
pub fn test_show_database_like() {
let sql = "SHOW DATABASES LIKE test_database";
let result = ParserContext::create_with_dialect(sql, &GenericDialect {});
let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {});
let stmts = result.unwrap();
assert_eq!(1, stmts.len());
@@ -435,7 +435,7 @@ mod tests {
#[test]
pub fn test_show_database_where() {
let sql = "SHOW DATABASES WHERE Database LIKE '%whatever1%' OR Database LIKE '%whatever2%'";
let result = ParserContext::create_with_dialect(sql, &GenericDialect {});
let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {});
let stmts = result.unwrap();
assert_eq!(1, stmts.len());
@@ -454,7 +454,7 @@ mod tests {
#[test]
pub fn test_show_tables_all() {
let sql = "SHOW TABLES";
let result = ParserContext::create_with_dialect(sql, &GenericDialect {});
let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {});
let stmts = result.unwrap();
assert_eq!(1, stmts.len());
@@ -470,7 +470,7 @@ mod tests {
#[test]
pub fn test_show_tables_like() {
let sql = "SHOW TABLES LIKE test_table";
let result = ParserContext::create_with_dialect(sql, &GenericDialect {});
let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {});
let stmts = result.unwrap();
assert_eq!(1, stmts.len());
@@ -486,7 +486,7 @@ mod tests {
);
let sql = "SHOW TABLES in test_db LIKE test_table";
let result = ParserContext::create_with_dialect(sql, &GenericDialect {});
let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {});
let stmts = result.unwrap();
assert_eq!(1, stmts.len());
@@ -505,7 +505,7 @@ mod tests {
#[test]
pub fn test_show_tables_where() {
let sql = "SHOW TABLES where name like test_table";
let result = ParserContext::create_with_dialect(sql, &GenericDialect {});
let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {});
let stmts = result.unwrap();
assert_eq!(1, stmts.len());
@@ -518,7 +518,7 @@ mod tests {
);
let sql = "SHOW TABLES in test_db where name LIKE test_table";
let result = ParserContext::create_with_dialect(sql, &GenericDialect {});
let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {});
let stmts = result.unwrap();
assert_eq!(1, stmts.len());
@@ -534,7 +534,7 @@ mod tests {
#[test]
pub fn test_explain() {
let sql = "EXPLAIN select * from foo";
let result = ParserContext::create_with_dialect(sql, &GenericDialect {});
let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {});
let stmts = result.unwrap();
assert_eq!(1, stmts.len());
@@ -589,7 +589,7 @@ mod tests {
#[test]
pub fn test_drop_table() {
let sql = "DROP TABLE foo";
let result = ParserContext::create_with_dialect(sql, &GenericDialect {});
let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {});
let mut stmts = result.unwrap();
assert_eq!(
stmts.pop().unwrap(),
@@ -597,7 +597,7 @@ mod tests {
);
let sql = "DROP TABLE my_schema.foo";
let result = ParserContext::create_with_dialect(sql, &GenericDialect {});
let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {});
let mut stmts = result.unwrap();
assert_eq!(
stmts.pop().unwrap(),
@@ -608,7 +608,7 @@ mod tests {
);
let sql = "DROP TABLE my_catalog.my_schema.foo";
let result = ParserContext::create_with_dialect(sql, &GenericDialect {});
let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {});
let mut stmts = result.unwrap();
assert_eq!(
stmts.pop().unwrap(),
@@ -621,7 +621,7 @@ mod tests {
}
fn test_timestamp_precision(sql: &str, expected_type: ConcreteDataType) {
match ParserContext::create_with_dialect(sql, &GenericDialect {})
match ParserContext::create_with_dialect(sql, &GreptimeDbDialect {})
.unwrap()
.pop()
.unwrap()
@@ -673,7 +673,7 @@ mod tests {
#[test]
fn test_parse_function() {
let expr =
ParserContext::parse_function("current_timestamp()", &GenericDialect {}).unwrap();
ParserContext::parse_function("current_timestamp()", &GreptimeDbDialect {}).unwrap();
assert!(matches!(expr, Expr::Function(_)));
}
}

View File

@@ -79,14 +79,14 @@ mod tests {
use std::assert_matches::assert_matches;
use sqlparser::ast::{ColumnOption, DataType};
use sqlparser::dialect::GenericDialect;
use super::*;
use crate::dialect::GreptimeDbDialect;
#[test]
fn test_parse_alter_add_column() {
let sql = "ALTER TABLE my_metric_1 ADD tagk_i STRING Null;";
let mut result = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap();
let mut result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap();
assert_eq!(1, result.len());
let statement = result.remove(0);
@@ -116,13 +116,13 @@ mod tests {
#[test]
fn test_parse_alter_drop_column() {
let sql = "ALTER TABLE my_metric_1 DROP a";
let result = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap_err();
let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap_err();
assert!(result
.to_string()
.contains("expect keyword COLUMN after ALTER TABLE DROP"));
let sql = "ALTER TABLE my_metric_1 DROP COLUMN a";
let mut result = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap();
let mut result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap();
assert_eq!(1, result.len());
let statement = result.remove(0);
@@ -147,13 +147,13 @@ mod tests {
#[test]
fn test_parse_alter_rename_table() {
let sql = "ALTER TABLE test_table table_t";
let result = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap_err();
let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap_err();
assert!(result
.to_string()
.contains("expect keyword ADD or DROP or RENAME after ALTER TABLE"));
let sql = "ALTER TABLE test_table RENAME table_t";
let mut result = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap();
let mut result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap();
assert_eq!(1, result.len());
let statement = result.remove(0);

View File

@@ -139,16 +139,15 @@ mod tests {
use std::assert_matches::assert_matches;
use std::collections::HashMap;
use sqlparser::dialect::GenericDialect;
use super::*;
use crate::dialect::GreptimeDbDialect;
#[test]
fn test_parse_copy_table() {
let sql0 = "COPY catalog0.schema0.tbl TO 'tbl_file.parquet'";
let sql1 = "COPY catalog0.schema0.tbl TO 'tbl_file.parquet' WITH (FORMAT = 'parquet')";
let result0 = ParserContext::create_with_dialect(sql0, &GenericDialect {}).unwrap();
let result1 = ParserContext::create_with_dialect(sql1, &GenericDialect {}).unwrap();
let result0 = ParserContext::create_with_dialect(sql0, &GreptimeDbDialect {}).unwrap();
let result1 = ParserContext::create_with_dialect(sql1, &GreptimeDbDialect {}).unwrap();
for mut result in vec![result0, result1] {
assert_eq!(1, result.len());
@@ -190,7 +189,7 @@ mod tests {
"COPY catalog0.schema0.tbl FROM 'tbl_file.parquet' WITH (FORMAT = 'parquet')",
]
.iter()
.map(|sql| ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap())
.map(|sql| ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap())
.collect::<Vec<_>>();
for mut result in results {
@@ -249,7 +248,7 @@ mod tests {
for test in tests {
let mut result =
ParserContext::create_with_dialect(test.sql, &GenericDialect {}).unwrap();
ParserContext::create_with_dialect(test.sql, &GreptimeDbDialect {}).unwrap();
assert_eq!(1, result.len());
let statement = result.remove(0);
@@ -290,7 +289,7 @@ mod tests {
for test in tests {
let mut result =
ParserContext::create_with_dialect(test.sql, &GenericDialect {}).unwrap();
ParserContext::create_with_dialect(test.sql, &GreptimeDbDialect {}).unwrap();
assert_eq!(1, result.len());
let statement = result.remove(0);

View File

@@ -784,9 +784,9 @@ mod tests {
use common_catalog::consts::IMMUTABLE_FILE_ENGINE;
use sqlparser::ast::ColumnOption::NotNull;
use sqlparser::dialect::GenericDialect;
use super::*;
use crate::dialect::GreptimeDbDialect;
#[test]
fn test_parse_create_external_table() {
@@ -822,7 +822,8 @@ mod tests {
];
for test in tests {
let stmts = ParserContext::create_with_dialect(test.sql, &GenericDialect {}).unwrap();
let stmts =
ParserContext::create_with_dialect(test.sql, &GreptimeDbDialect {}).unwrap();
assert_eq!(1, stmts.len());
match &stmts[0] {
Statement::CreateExternalTable(c) => {
@@ -852,7 +853,7 @@ mod tests {
("format".to_string(), "csv".to_string()),
]);
let stmts = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap();
let stmts = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap();
assert_eq!(1, stmts.len());
match &stmts[0] {
Statement::CreateExternalTable(c) => {
@@ -888,14 +889,14 @@ mod tests {
#[test]
fn test_parse_create_database() {
let sql = "create database";
let result = ParserContext::create_with_dialect(sql, &GenericDialect {});
let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {});
assert!(result
.unwrap_err()
.to_string()
.contains("Unexpected token while parsing SQL statement"));
let sql = "create database prometheus";
let stmts = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap();
let stmts = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap();
assert_eq!(1, stmts.len());
match &stmts[0] {
@@ -907,7 +908,7 @@ mod tests {
}
let sql = "create database if not exists prometheus";
let stmts = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap();
let stmts = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap();
assert_eq!(1, stmts.len());
match &stmts[0] {
@@ -929,7 +930,7 @@ PARTITION BY RANGE COLUMNS(b, a) (
PARTITION r3 VALUES LESS THAN (MAXVALUE, MAXVALUE),
)
ENGINE=mito";
let result = ParserContext::create_with_dialect(sql, &GenericDialect {});
let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {});
assert!(result.is_ok());
let sql = r"
@@ -940,7 +941,7 @@ PARTITION BY RANGE COLUMNS(b, x) (
PARTITION r3 VALUES LESS THAN (MAXVALUE, MAXVALUE),
)
ENGINE=mito";
let result = ParserContext::create_with_dialect(sql, &GenericDialect {});
let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {});
assert!(result
.unwrap_err()
.to_string()
@@ -955,7 +956,7 @@ PARTITION BY RANGE COLUMNS(b, a) (
PARTITION r1 VALUES LESS THAN (MAXVALUE, MAXVALUE),
)
ENGINE=mito";
let result = ParserContext::create_with_dialect(sql, &GenericDialect {});
let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {});
assert!(result
.unwrap_err()
.to_string()
@@ -969,7 +970,7 @@ PARTITION BY RANGE COLUMNS(b, a) (
PARTITION r3 VALUES LESS THAN (MAXVALUE, MAXVALUE),
)
ENGINE=mito";
let result = ParserContext::create_with_dialect(sql, &GenericDialect {});
let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {});
assert!(result
.unwrap_err()
.to_string()
@@ -1010,7 +1011,7 @@ PARTITION BY RANGE COLUMNS(b, a) (
ENGINE=mito",
];
for sql in cases {
let result = ParserContext::create_with_dialect(sql, &GenericDialect {});
let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {});
assert!(result
.unwrap_err()
.to_string()
@@ -1025,7 +1026,7 @@ PARTITION BY RANGE COLUMNS(b, a) (
PARTITION r3 VALUES LESS THAN (MAXVALUE, 9999),
)
ENGINE=mito";
let result = ParserContext::create_with_dialect(sql, &GenericDialect {});
let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {});
assert!(result
.unwrap_err()
.to_string()
@@ -1051,7 +1052,7 @@ PARTITION BY RANGE COLUMNS(idc, host_id) (
PARTITION r3 VALUES LESS THAN (MAXVALUE, MAXVALUE),
)
ENGINE=mito";
let result = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap();
let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap();
assert_eq!(result.len(), 1);
match &result[0] {
Statement::CreateTable(c) => {
@@ -1117,7 +1118,7 @@ CREATE TABLE monitor (
PRIMARY KEY (host),
)
ENGINE=mito";
let result1 = ParserContext::create_with_dialect(sql1, &GenericDialect {}).unwrap();
let result1 = ParserContext::create_with_dialect(sql1, &GreptimeDbDialect {}).unwrap();
if let Statement::CreateTable(c) = &result1[0] {
assert_eq!(c.constraints.len(), 2);
@@ -1152,7 +1153,7 @@ CREATE TABLE monitor (
PRIMARY KEY (host),
)
ENGINE=mito";
let result2 = ParserContext::create_with_dialect(sql2, &GenericDialect {}).unwrap();
let result2 = ParserContext::create_with_dialect(sql2, &GreptimeDbDialect {}).unwrap();
assert_eq!(result1, result2);
@@ -1169,7 +1170,7 @@ CREATE TABLE monitor (
)
ENGINE=mito";
let result3 = ParserContext::create_with_dialect(sql3, &GenericDialect {}).unwrap();
let result3 = ParserContext::create_with_dialect(sql3, &GreptimeDbDialect {}).unwrap();
assert_ne!(result1, result3);
@@ -1184,7 +1185,7 @@ CREATE TABLE monitor (
PRIMARY KEY (host),
)
ENGINE=mito";
let result1 = ParserContext::create_with_dialect(sql1, &GenericDialect {}).unwrap();
let result1 = ParserContext::create_with_dialect(sql1, &GreptimeDbDialect {}).unwrap();
if let Statement::CreateTable(c) = &result1[0] {
assert_eq!(c.constraints.len(), 2);
@@ -1220,7 +1221,7 @@ CREATE TABLE monitor (
PRIMARY KEY (host),
)
ENGINE=mito";
let result = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap();
let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap();
assert_eq!(result.len(), 1);
if let Statement::CreateTable(c) = &result[0] {
@@ -1243,7 +1244,7 @@ CREATE TABLE monitor (
)
ENGINE=mito";
let result1 = ParserContext::create_with_dialect(sql1, &GenericDialect {}).unwrap();
let result1 = ParserContext::create_with_dialect(sql1, &GreptimeDbDialect {}).unwrap();
assert_eq!(result, result1);
let sql2 = r"
@@ -1258,7 +1259,7 @@ CREATE TABLE monitor (
)
ENGINE=mito";
let result2 = ParserContext::create_with_dialect(sql2, &GenericDialect {}).unwrap();
let result2 = ParserContext::create_with_dialect(sql2, &GreptimeDbDialect {}).unwrap();
assert_eq!(result, result2);
let sql3 = r"
@@ -1273,7 +1274,7 @@ CREATE TABLE monitor (
)
ENGINE=mito";
let result3 = ParserContext::create_with_dialect(sql3, &GenericDialect {});
let result3 = ParserContext::create_with_dialect(sql3, &GreptimeDbDialect {});
assert!(result3.is_err());
let sql4 = r"
@@ -1288,7 +1289,7 @@ CREATE TABLE monitor (
)
ENGINE=mito";
let result4 = ParserContext::create_with_dialect(sql4, &GenericDialect {});
let result4 = ParserContext::create_with_dialect(sql4, &GreptimeDbDialect {});
assert!(result4.is_err());
let sql = r"
@@ -1303,7 +1304,7 @@ CREATE TABLE monitor (
)
ENGINE=mito";
let result = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap();
let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap();
if let Statement::CreateTable(c) = &result[0] {
let tc = c.constraints[0].clone();
@@ -1339,7 +1340,7 @@ PARTITION RANGE COLUMNS(b, a) (
PARTITION r3 VALUES LESS THAN (MAXVALUE, MAXVALUE),
)
ENGINE=mito";
let result = ParserContext::create_with_dialect(sql, &GenericDialect {});
let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {});
assert!(result
.unwrap_err()
.to_string()
@@ -1353,7 +1354,7 @@ PARTITION BY RANGE COLUMNS(b, a) (
PARTITION r3 VALUES LESS THAN (MAXVALUE, MAXVALUE),
)
ENGINE=mito";
let result = ParserContext::create_with_dialect(sql, &GenericDialect {});
let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {});
assert!(result
.unwrap_err()
.to_string()
@@ -1367,11 +1368,11 @@ PARTITION BY RANGE COLUMNS(b, a) (
PARTITION r3 VALUES LESS THAN (MAXVALUE, MAXVALU),
)
ENGINE=mito";
let result = ParserContext::create_with_dialect(sql, &GenericDialect {});
let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {});
assert!(result
.unwrap_err()
.to_string()
.contains("Please provide an extra partition that is bounded by 'MAXVALUE'."));
.contains("Expected a concrete value, found: MAXVALU"));
}
fn assert_column_def(column: &ColumnDef, name: &str, data_type: &str) {
@@ -1390,7 +1391,7 @@ ENGINE=mito";
PRIMARY KEY(ts, host)) engine=mito
with(regions=1);
";
let result = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap();
let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap();
assert_eq!(1, result.len());
match &result[0] {
Statement::CreateTable(c) => {
@@ -1438,7 +1439,7 @@ ENGINE=mito";
PRIMARY KEY(ts, host)) engine=mito
with(regions=1);
";
let result = ParserContext::create_with_dialect(sql, &GenericDialect {});
let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {});
assert!(result.is_err());
assert_matches!(result, Err(crate::error::Error::InvalidTimeIndex { .. }));
}
@@ -1455,7 +1456,7 @@ ENGINE=mito";
PRIMARY KEY(ts, host)) engine=mito
with(regions=1);
";
let result = ParserContext::create_with_dialect(sql, &GenericDialect {});
let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {});
assert!(result.is_err());
assert_matches!(result, Err(crate::error::Error::InvalidColumnOption { .. }));
@@ -1469,7 +1470,7 @@ ENGINE=mito";
PRIMARY KEY(ts, host)) engine=mito
with(regions=1);
";
let result = ParserContext::create_with_dialect(sql, &GenericDialect {});
let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {});
assert!(result.is_err());
assert_matches!(result, Err(crate::error::Error::InvalidTimeIndex { .. }));
}
@@ -1477,7 +1478,7 @@ ENGINE=mito";
#[test]
fn test_invalid_column_name() {
let sql = "create table foo(user string, i bigint time index)";
let result = ParserContext::create_with_dialect(sql, &GenericDialect {});
let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {});
assert!(result
.unwrap_err()
.to_string()
@@ -1487,7 +1488,7 @@ ENGINE=mito";
let sql = r#"
create table foo("user" string, i bigint time index)
"#;
let result = ParserContext::create_with_dialect(sql, &GenericDialect {});
let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {});
assert!(result.is_ok());
}
}

View File

@@ -46,14 +46,13 @@ impl<'a> ParserContext<'a> {
mod tests {
use std::assert_matches::assert_matches;
use sqlparser::dialect::GenericDialect;
use super::*;
use crate::dialect::GreptimeDbDialect;
#[test]
pub fn test_parse_insert() {
let sql = r"delete from my_table where k1 = xxx and k2 = xxx and timestamp = xxx;";
let result = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap();
let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap();
assert_eq!(1, result.len());
assert_matches!(result[0], Statement::Delete { .. })
}
@@ -61,7 +60,7 @@ mod tests {
#[test]
pub fn test_parse_invalid_insert() {
let sql = r"delete my_table where "; // intentionally a bad sql
let result = ParserContext::create_with_dialect(sql, &GenericDialect {});
let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {});
assert!(result.is_err(), "result is: {result:?}");
}
}

View File

@@ -46,9 +46,8 @@ impl<'a> ParserContext<'a> {
mod tests {
use std::assert_matches::assert_matches;
use sqlparser::dialect::GenericDialect;
use super::*;
use crate::dialect::GreptimeDbDialect;
#[test]
pub fn test_parse_insert() {
@@ -56,7 +55,7 @@ mod tests {
'test1',1,'true',
'test2',2,'false')
";
let result = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap();
let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap();
assert_eq!(1, result.len());
assert_matches!(result[0], Statement::Insert { .. })
}
@@ -64,7 +63,7 @@ mod tests {
#[test]
pub fn test_parse_invalid_insert() {
let sql = r"INSERT INTO table_1 VALUES ("; // intentionally a bad sql
let result = ParserContext::create_with_dialect(sql, &GenericDialect {});
let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {});
assert!(result.is_err(), "result is: {result:?}");
}
}

View File

@@ -33,8 +33,7 @@ impl<'a> ParserContext<'a> {
#[cfg(test)]
mod tests {
use sqlparser::dialect::GenericDialect;
use crate::dialect::GreptimeDbDialect;
use crate::parser::ParserContext;
#[test]
@@ -44,13 +43,13 @@ mod tests {
WHERE a > b AND b < 100 \
ORDER BY a DESC, b";
let _ = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap();
let _ = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap();
}
#[test]
pub fn test_parse_invalid_query() {
let sql = "SELECT * FROM table_1 WHERE";
let result = ParserContext::create_with_dialect(sql, &GenericDialect {});
let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {});
assert!(result.is_err());
assert!(result
.unwrap_err()

View File

@@ -166,14 +166,13 @@ impl<'a> ParserContext<'a> {
#[cfg(test)]
mod tests {
use sqlparser::dialect::GenericDialect;
use super::*;
use crate::dialect::GreptimeDbDialect;
#[test]
fn test_parse_tql_eval() {
let sql = "TQL EVAL (1676887657, 1676887659, '1m') http_requests_total{environment=~'staging|testing|development',method!='GET'} @ 1609746000 offset 5m";
let mut result = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap();
let mut result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap();
assert_eq!(1, result.len());
let statement = result.remove(0);
@@ -189,7 +188,7 @@ mod tests {
let sql = "TQL EVAL (1676887657.1, 1676887659.5, 30.3) http_requests_total{environment=~'staging|testing|development',method!='GET'} @ 1609746000 offset 5m";
let mut result = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap();
let mut result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap();
assert_eq!(1, result.len());
let statement = result.remove(0);
@@ -205,7 +204,7 @@ mod tests {
let sql = "TQL EVALUATE (1676887657.1, 1676887659.5, 30.3) http_requests_total{environment=~'staging|testing|development',method!='GET'} @ 1609746000 offset 5m";
let mut result = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap();
let mut result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap();
assert_eq!(1, result.len());
let statement2 = result.remove(0);
@@ -213,7 +212,7 @@ mod tests {
let sql = "tql eval ('2015-07-01T20:10:30.781Z', '2015-07-01T20:11:00.781Z', '30s') http_requests_total{environment=~'staging|testing|development',method!='GET'} @ 1609746000 offset 5m";
let mut result = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap();
let mut result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap();
assert_eq!(1, result.len());
let statement = result.remove(0);
@@ -232,7 +231,7 @@ mod tests {
fn test_parse_tql_explain() {
let sql = "TQL EXPLAIN http_requests_total{environment=~'staging|testing|development',method!='GET'} @ 1609746000 offset 5m";
let mut result = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap();
let mut result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap();
assert_eq!(1, result.len());
let statement = result.remove(0);
@@ -248,7 +247,7 @@ mod tests {
let sql = "TQL EXPLAIN (20,100,10) http_requests_total{environment=~'staging|testing|development',method!='GET'} @ 1609746000 offset 5m";
let mut result = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap();
let mut result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap();
assert_eq!(1, result.len());
let statement = result.remove(0);
@@ -266,7 +265,7 @@ mod tests {
#[test]
fn test_parse_tql_analyze() {
let sql = "TQL ANALYZE (1676887657.1, 1676887659.5, 30.3) http_requests_total{environment=~'staging|testing|development',method!='GET'} @ 1609746000 offset 5m";
let mut result = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap();
let mut result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap();
assert_eq!(1, result.len());
let statement = result.remove(0);
match statement {
@@ -284,12 +283,12 @@ mod tests {
fn test_parse_tql_error() {
// Invalid duration
let sql = "TQL EVAL (1676887657, 1676887659, 1m) http_requests_total{environment=~'staging|testing|development',method!='GET'} @ 1609746000 offset 5m";
let result = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap_err();
let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap_err();
assert!(result.to_string().contains("Expected ), found: m"));
// missing end
let sql = "TQL EVAL (1676887657, '1m') http_requests_total{environment=~'staging|testing|development',method!='GET'} @ 1609746000 offset 5m";
let result = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap_err();
let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap_err();
assert!(result.to_string().contains("Expected ,, found: )"));
}
}

View File

@@ -206,8 +206,7 @@ pub struct CreateExternalTable {
#[cfg(test)]
mod tests {
use sqlparser::dialect::GenericDialect;
use crate::dialect::GreptimeDbDialect;
use crate::parser::ParserContext;
use crate::statements::statement::Statement;
@@ -229,7 +228,7 @@ mod tests {
engine=mito
with(regions=1, ttl='7d');
";
let result = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap();
let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap();
assert_eq!(1, result.len());
match &result[0] {
@@ -259,7 +258,7 @@ WITH(
);
let new_result =
ParserContext::create_with_dialect(&new_sql, &GenericDialect {}).unwrap();
ParserContext::create_with_dialect(&new_sql, &GreptimeDbDialect {}).unwrap();
assert_eq!(result, new_result);
}
_ => unreachable!(),

View File

@@ -35,8 +35,7 @@ impl DescribeTable {
mod tests {
use std::assert_matches::assert_matches;
use sqlparser::dialect::GenericDialect;
use crate::dialect::GreptimeDbDialect;
use crate::parser::ParserContext;
use crate::statements::statement::Statement;
@@ -44,7 +43,7 @@ mod tests {
pub fn test_describe_table() {
let sql = "DESCRIBE TABLE test";
let stmts: Vec<Statement> =
ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap();
ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap();
assert_eq!(1, stmts.len());
assert_matches!(&stmts[0], Statement::DescribeTable { .. });
match &stmts[0] {
@@ -61,7 +60,7 @@ mod tests {
pub fn test_describe_schema_table() {
let sql = "DESCRIBE TABLE test_schema.test";
let stmts: Vec<Statement> =
ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap();
ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap();
assert_eq!(1, stmts.len());
assert_matches!(&stmts[0], Statement::DescribeTable { .. });
match &stmts[0] {
@@ -78,7 +77,7 @@ mod tests {
pub fn test_describe_catalog_schema_table() {
let sql = "DESCRIBE TABLE test_catalog.test_schema.test";
let stmts: Vec<Statement> =
ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap();
ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap();
assert_eq!(1, stmts.len());
assert_matches!(&stmts[0], Statement::DescribeTable { .. });
match &stmts[0] {
@@ -94,6 +93,6 @@ mod tests {
#[test]
pub fn test_describe_missing_table_name() {
let sql = "DESCRIBE TABLE";
ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap_err();
ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap_err();
}
}

View File

@@ -136,9 +136,8 @@ impl TryFrom<Statement> for Insert {
#[cfg(test)]
mod tests {
use sqlparser::dialect::GenericDialect;
use super::*;
use crate::dialect::GreptimeDbDialect;
use crate::parser::ParserContext;
use crate::statements::statement::Statement;
@@ -146,7 +145,7 @@ mod tests {
fn test_insert_value_with_unary_op() {
// insert "-1"
let sql = "INSERT INTO my_table VALUES(-1)";
let stmt = ParserContext::create_with_dialect(sql, &GenericDialect {})
let stmt = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {})
.unwrap()
.remove(0);
match stmt {
@@ -159,7 +158,7 @@ mod tests {
// insert "+1"
let sql = "INSERT INTO my_table VALUES(+1)";
let stmt = ParserContext::create_with_dialect(sql, &GenericDialect {})
let stmt = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {})
.unwrap()
.remove(0);
match stmt {
@@ -175,7 +174,7 @@ mod tests {
fn test_insert_value_with_default() {
// insert "default"
let sql = "INSERT INTO my_table VALUES(default)";
let stmt = ParserContext::create_with_dialect(sql, &GenericDialect {})
let stmt = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {})
.unwrap()
.remove(0);
match stmt {
@@ -191,7 +190,7 @@ mod tests {
fn test_insert_value_with_default_uppercase() {
// insert "DEFAULT"
let sql = "INSERT INTO my_table VALUES(DEFAULT)";
let stmt = ParserContext::create_with_dialect(sql, &GenericDialect {})
let stmt = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {})
.unwrap()
.remove(0);
match stmt {
@@ -207,7 +206,7 @@ mod tests {
fn test_insert_value_with_quoted_string() {
// insert "'default'"
let sql = "INSERT INTO my_table VALUES('default')";
let stmt = ParserContext::create_with_dialect(sql, &GenericDialect {})
let stmt = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {})
.unwrap()
.remove(0);
match stmt {
@@ -225,7 +224,7 @@ mod tests {
#[test]
fn test_insert_select() {
let sql = "INSERT INTO my_table select * from other_table";
let stmt = ParserContext::create_with_dialect(sql, &GenericDialect {})
let stmt = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {})
.unwrap()
.remove(0);
match stmt {

View File

@@ -74,14 +74,13 @@ impl fmt::Display for Query {
#[cfg(test)]
mod test {
use sqlparser::dialect::GenericDialect;
use super::Query;
use crate::dialect::GreptimeDbDialect;
use crate::parser::ParserContext;
use crate::statements::statement::Statement;
fn create_query(sql: &str) -> Option<Box<Query>> {
match ParserContext::create_with_dialect(sql, &GenericDialect {})
match ParserContext::create_with_dialect(sql, &GreptimeDbDialect {})
.unwrap()
.remove(0)
{

View File

@@ -65,9 +65,9 @@ mod tests {
use std::assert_matches::assert_matches;
use sqlparser::ast::UnaryOperator;
use sqlparser::dialect::GenericDialect;
use super::*;
use crate::dialect::GreptimeDbDialect;
use crate::parser::ParserContext;
use crate::statements::statement::Statement;
@@ -102,7 +102,7 @@ mod tests {
#[test]
pub fn test_show_database() {
let sql = "SHOW DATABASES";
let stmts = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap();
let stmts = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap();
assert_eq!(1, stmts.len());
assert_matches!(&stmts[0], Statement::ShowDatabases { .. });
match &stmts[0] {
@@ -119,7 +119,7 @@ mod tests {
pub fn test_show_create_table() {
let sql = "SHOW CREATE TABLE test";
let stmts: Vec<Statement> =
ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap();
ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap();
assert_eq!(1, stmts.len());
assert_matches!(&stmts[0], Statement::ShowCreateTable { .. });
match &stmts[0] {
@@ -135,6 +135,6 @@ mod tests {
#[test]
pub fn test_show_create_missing_table_name() {
let sql = "SHOW CREATE TABLE";
ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap_err();
ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}).unwrap_err();
}
}