feat: sql dialect for different protocols (#1631)

* feat: add SqlDialect to query context

* feat: use session in postgrel handlers

* chore: refactor sql dialect

* feat: use different dialects for different sql protocols

* feat: adds GreptimeDbDialect

* refactor: replace GenericDialect with GreptimeDbDialect

* feat: save user info to session

* fix: compile error

* fix: test
This commit is contained in:
dennis zhuang
2023-05-30 09:52:35 +08:00
committed by GitHub
parent 563ce59071
commit ab5dfd31ec
31 changed files with 285 additions and 185 deletions

View File

@@ -32,9 +32,9 @@ use opensrv_mysql::{
use parking_lot::RwLock;
use rand::RngCore;
use session::context::Channel;
use session::Session;
use session::{Session, SessionRef};
use snafu::ensure;
use sql::dialect::GenericDialect;
use sql::dialect::MySqlDialect;
use sql::parser::ParserContext;
use sql::statements::statement::Statement;
use tokio::io::AsyncWrite;
@@ -48,7 +48,7 @@ use crate::query_handler::sql::ServerSqlQueryHandlerRef;
pub struct MysqlInstanceShim {
query_handler: ServerSqlQueryHandlerRef,
salt: [u8; 20],
session: Arc<Session>,
session: SessionRef,
user_provider: Option<UserProviderRef>,
// TODO(SSebo): use something like moka to achieve TTL or LRU
prepared_stmts: Arc<RwLock<HashMap<u32, String>>>,
@@ -77,7 +77,7 @@ impl MysqlInstanceShim {
MysqlInstanceShim {
query_handler,
salt: scramble,
session: Arc::new(Session::new(client_addr, Channel::Mysql)),
session: Arc::new(Session::new(Some(client_addr), Channel::Mysql)),
user_provider,
prepared_stmts: Default::default(),
prepared_stmts_counter: AtomicU32::new(1),
@@ -140,9 +140,13 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
let username = String::from_utf8_lossy(username);
let mut user_info = None;
let addr = self.session.conn_info().client_host.to_string();
let addr = self
.session
.conn_info()
.client_addr
.map(|addr| addr.to_string());
if let Some(user_provider) = &self.user_provider {
let user_id = Identity::UserId(&username, Some(addr.as_str()));
let user_id = Identity::UserId(&username, addr.as_deref());
let password = match auth_plugin {
"mysql_native_password" => Password::MysqlNativePassword(auth_data, salt),
@@ -331,7 +335,7 @@ fn format_duration(duration: Duration) -> String {
}
async fn validate_query(query: &str) -> Result<Statement> {
let statement = ParserContext::create_with_dialect(query, &GenericDialect {});
let statement = ParserContext::create_with_dialect(query, &MySqlDialect {});
let mut statement = statement.map_err(|e| {
InvalidPrepareStatementSnafu {
err_msg: e.to_string(),

View File

@@ -31,7 +31,8 @@ use pgwire::api::auth::ServerParameterProvider;
use pgwire::api::store::MemPortalStore;
use pgwire::api::{ClientInfo, MakeHandler};
pub use server::PostgresServer;
use session::context::{QueryContext, QueryContextRef};
use session::context::Channel;
use session::Session;
use sql::statements::statement::Statement;
use self::auth_handler::PgLoginVerifier;
@@ -73,7 +74,7 @@ pub struct PostgresServerHandler {
force_tls: bool,
param_provider: Arc<GreptimeDBStartupParameters>,
query_ctx: QueryContextRef,
session: Session,
portal_store: Arc<MemPortalStore<(Statement, String)>>,
query_parser: Arc<POCQueryParser>,
}
@@ -90,18 +91,18 @@ pub(crate) struct MakePostgresServerHandler {
}
impl MakeHandler for MakePostgresServerHandler {
type Handler = Arc<PostgresServerHandler>;
type Handler = PostgresServerHandler;
fn make(&self) -> Self::Handler {
Arc::new(PostgresServerHandler {
PostgresServerHandler {
query_handler: self.query_handler.clone(),
login_verifier: PgLoginVerifier::new(self.user_provider.clone()),
force_tls: self.force_tls,
param_provider: self.param_provider.clone(),
query_ctx: QueryContext::arc(),
session: Session::new(None, Channel::Postgres),
portal_store: Arc::new(MemPortalStore::new()),
query_parser: self.query_parser.clone(),
})
}
}
}

View File

@@ -24,7 +24,8 @@ use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
use pgwire::messages::response::ErrorResponse;
use pgwire::messages::startup::Authentication;
use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage};
use session::context::QueryContextRef;
use session::context::UserInfo;
use session::Session;
use super::PostgresServerHandler;
use crate::auth::{Identity, Password, UserProviderRef};
@@ -112,15 +113,19 @@ impl PgLoginVerifier {
}
}
fn set_query_context_from_client_info<C>(client: &C, query_context: QueryContextRef)
fn set_client_info<C>(client: &C, session: &Session)
where
C: ClientInfo,
{
let ctx = session.context();
if let Some(current_catalog) = client.metadata().get(super::METADATA_CATALOG) {
query_context.set_current_catalog(current_catalog);
ctx.set_current_catalog(current_catalog);
}
if let Some(current_schema) = client.metadata().get(super::METADATA_SCHEMA) {
query_context.set_current_schema(current_schema);
ctx.set_current_schema(current_schema);
}
if let Some(username) = client.metadata().get(super::METADATA_USER) {
session.set_user_info(UserInfo::new(username));
}
}
@@ -170,7 +175,7 @@ impl StartupHandler for PostgresServerHandler {
))
.await?;
} else {
set_query_context_from_client_info(client, self.query_ctx.clone());
set_client_info(client, &self.session);
auth::finish_authentication(client, self.param_provider.as_ref()).await;
}
}
@@ -193,7 +198,7 @@ impl StartupHandler for PostgresServerHandler {
)
.await;
}
set_query_context_from_client_info(client, self.query_ctx.clone());
set_client_info(client, &self.session);
auth::finish_authentication(client, self.param_provider.as_ref()).await;
}
_ => {}

View File

@@ -33,7 +33,7 @@ use pgwire::api::stmt::QueryParser;
use pgwire::api::store::MemPortalStore;
use pgwire::api::{ClientInfo, Type};
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
use sql::dialect::GenericDialect;
use sql::dialect::PostgreSqlDialect;
use sql::parser::ParserContext;
use sql::statements::statement::Statement;
@@ -55,13 +55,13 @@ impl SimpleQueryHandler for PostgresServerHandler {
),
(
crate::metrics::METRIC_DB_LABEL,
self.query_ctx.get_db_string()
self.session.context().get_db_string()
)
]
);
let outputs = self
.query_handler
.do_query(query, self.query_ctx.clone())
.do_query(query, self.session.context())
.await;
let mut results = Vec::with_capacity(outputs.len());
@@ -260,7 +260,7 @@ impl QueryParser for POCQueryParser {
fn parse_sql(&self, sql: &str, types: &[Type]) -> PgWireResult<Self::Statement> {
increment_counter!(crate::metrics::METRIC_POSTGRES_PREPARED_COUNT);
let mut stmts = ParserContext::create_with_dialect(sql, &GenericDialect {})
let mut stmts = ParserContext::create_with_dialect(sql, &PostgreSqlDialect {})
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
if stmts.len() != 1 {
Err(PgWireError::UserError(Box::new(ErrorInfo::new(
@@ -361,7 +361,7 @@ impl ExtendedQueryHandler for PostgresServerHandler {
),
(
crate::metrics::METRIC_DB_LABEL,
self.query_ctx.get_db_string()
self.session.context().get_db_string()
)
]
);
@@ -376,7 +376,7 @@ impl ExtendedQueryHandler for PostgresServerHandler {
let output = self
.query_handler
.do_query(&sql, self.query_ctx.clone())
.do_query(&sql, self.session.context())
.await
.remove(0);
@@ -407,7 +407,7 @@ impl ExtendedQueryHandler for PostgresServerHandler {
if let Some(schema) = self
.query_handler
.do_describe(stmt.clone(), self.query_ctx.clone())
.do_describe(stmt.clone(), self.session.context())
.await
.map_err(|e| PgWireError::ApiError(Box::new(e)))?
{

View File

@@ -73,19 +73,22 @@ impl PostgresServer {
accepting_stream.for_each(move |tcp_stream| {
let io_runtime = io_runtime.clone();
let tls_acceptor = tls_acceptor.clone();
let handler = handler.make();
let mut handler = handler.make();
async move {
match tcp_stream {
Err(error) => error!("Broken pipe: {}", error), // IoError doesn't impl ErrorExt.
Ok(io_stream) => {
match io_stream.peer_addr() {
Ok(addr) => debug!("PostgreSQL client coming from {}", addr),
Ok(addr) => {
handler.session.mut_conn_info().client_addr = Some(addr);
debug!("PostgreSQL client coming from {}", addr)
}
Err(e) => warn!("Failed to get PostgreSQL client addr, err: {}", e),
}
io_runtime.spawn(async move {
increment_gauge!(crate::metrics::METRIC_POSTGRES_CONNECTIONS, 1.0);
let handler = Arc::new(handler);
let r = process_socket(
io_stream,
tls_acceptor.clone(),