mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-23 00:10:38 +00:00
feat: sql dialect for different protocols (#1631)
* feat: add SqlDialect to query context * feat: use session in postgrel handlers * chore: refactor sql dialect * feat: use different dialects for different sql protocols * feat: adds GreptimeDbDialect * refactor: replace GenericDialect with GreptimeDbDialect * feat: save user info to session * fix: compile error * fix: test
This commit is contained in:
@@ -32,9 +32,9 @@ use opensrv_mysql::{
|
||||
use parking_lot::RwLock;
|
||||
use rand::RngCore;
|
||||
use session::context::Channel;
|
||||
use session::Session;
|
||||
use session::{Session, SessionRef};
|
||||
use snafu::ensure;
|
||||
use sql::dialect::GenericDialect;
|
||||
use sql::dialect::MySqlDialect;
|
||||
use sql::parser::ParserContext;
|
||||
use sql::statements::statement::Statement;
|
||||
use tokio::io::AsyncWrite;
|
||||
@@ -48,7 +48,7 @@ use crate::query_handler::sql::ServerSqlQueryHandlerRef;
|
||||
pub struct MysqlInstanceShim {
|
||||
query_handler: ServerSqlQueryHandlerRef,
|
||||
salt: [u8; 20],
|
||||
session: Arc<Session>,
|
||||
session: SessionRef,
|
||||
user_provider: Option<UserProviderRef>,
|
||||
// TODO(SSebo): use something like moka to achieve TTL or LRU
|
||||
prepared_stmts: Arc<RwLock<HashMap<u32, String>>>,
|
||||
@@ -77,7 +77,7 @@ impl MysqlInstanceShim {
|
||||
MysqlInstanceShim {
|
||||
query_handler,
|
||||
salt: scramble,
|
||||
session: Arc::new(Session::new(client_addr, Channel::Mysql)),
|
||||
session: Arc::new(Session::new(Some(client_addr), Channel::Mysql)),
|
||||
user_provider,
|
||||
prepared_stmts: Default::default(),
|
||||
prepared_stmts_counter: AtomicU32::new(1),
|
||||
@@ -140,9 +140,13 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
|
||||
let username = String::from_utf8_lossy(username);
|
||||
|
||||
let mut user_info = None;
|
||||
let addr = self.session.conn_info().client_host.to_string();
|
||||
let addr = self
|
||||
.session
|
||||
.conn_info()
|
||||
.client_addr
|
||||
.map(|addr| addr.to_string());
|
||||
if let Some(user_provider) = &self.user_provider {
|
||||
let user_id = Identity::UserId(&username, Some(addr.as_str()));
|
||||
let user_id = Identity::UserId(&username, addr.as_deref());
|
||||
|
||||
let password = match auth_plugin {
|
||||
"mysql_native_password" => Password::MysqlNativePassword(auth_data, salt),
|
||||
@@ -331,7 +335,7 @@ fn format_duration(duration: Duration) -> String {
|
||||
}
|
||||
|
||||
async fn validate_query(query: &str) -> Result<Statement> {
|
||||
let statement = ParserContext::create_with_dialect(query, &GenericDialect {});
|
||||
let statement = ParserContext::create_with_dialect(query, &MySqlDialect {});
|
||||
let mut statement = statement.map_err(|e| {
|
||||
InvalidPrepareStatementSnafu {
|
||||
err_msg: e.to_string(),
|
||||
|
||||
@@ -31,7 +31,8 @@ use pgwire::api::auth::ServerParameterProvider;
|
||||
use pgwire::api::store::MemPortalStore;
|
||||
use pgwire::api::{ClientInfo, MakeHandler};
|
||||
pub use server::PostgresServer;
|
||||
use session::context::{QueryContext, QueryContextRef};
|
||||
use session::context::Channel;
|
||||
use session::Session;
|
||||
use sql::statements::statement::Statement;
|
||||
|
||||
use self::auth_handler::PgLoginVerifier;
|
||||
@@ -73,7 +74,7 @@ pub struct PostgresServerHandler {
|
||||
force_tls: bool,
|
||||
param_provider: Arc<GreptimeDBStartupParameters>,
|
||||
|
||||
query_ctx: QueryContextRef,
|
||||
session: Session,
|
||||
portal_store: Arc<MemPortalStore<(Statement, String)>>,
|
||||
query_parser: Arc<POCQueryParser>,
|
||||
}
|
||||
@@ -90,18 +91,18 @@ pub(crate) struct MakePostgresServerHandler {
|
||||
}
|
||||
|
||||
impl MakeHandler for MakePostgresServerHandler {
|
||||
type Handler = Arc<PostgresServerHandler>;
|
||||
type Handler = PostgresServerHandler;
|
||||
|
||||
fn make(&self) -> Self::Handler {
|
||||
Arc::new(PostgresServerHandler {
|
||||
PostgresServerHandler {
|
||||
query_handler: self.query_handler.clone(),
|
||||
login_verifier: PgLoginVerifier::new(self.user_provider.clone()),
|
||||
force_tls: self.force_tls,
|
||||
param_provider: self.param_provider.clone(),
|
||||
|
||||
query_ctx: QueryContext::arc(),
|
||||
session: Session::new(None, Channel::Postgres),
|
||||
portal_store: Arc::new(MemPortalStore::new()),
|
||||
query_parser: self.query_parser.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,7 +24,8 @@ use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
|
||||
use pgwire::messages::response::ErrorResponse;
|
||||
use pgwire::messages::startup::Authentication;
|
||||
use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage};
|
||||
use session::context::QueryContextRef;
|
||||
use session::context::UserInfo;
|
||||
use session::Session;
|
||||
|
||||
use super::PostgresServerHandler;
|
||||
use crate::auth::{Identity, Password, UserProviderRef};
|
||||
@@ -112,15 +113,19 @@ impl PgLoginVerifier {
|
||||
}
|
||||
}
|
||||
|
||||
fn set_query_context_from_client_info<C>(client: &C, query_context: QueryContextRef)
|
||||
fn set_client_info<C>(client: &C, session: &Session)
|
||||
where
|
||||
C: ClientInfo,
|
||||
{
|
||||
let ctx = session.context();
|
||||
if let Some(current_catalog) = client.metadata().get(super::METADATA_CATALOG) {
|
||||
query_context.set_current_catalog(current_catalog);
|
||||
ctx.set_current_catalog(current_catalog);
|
||||
}
|
||||
if let Some(current_schema) = client.metadata().get(super::METADATA_SCHEMA) {
|
||||
query_context.set_current_schema(current_schema);
|
||||
ctx.set_current_schema(current_schema);
|
||||
}
|
||||
if let Some(username) = client.metadata().get(super::METADATA_USER) {
|
||||
session.set_user_info(UserInfo::new(username));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -170,7 +175,7 @@ impl StartupHandler for PostgresServerHandler {
|
||||
))
|
||||
.await?;
|
||||
} else {
|
||||
set_query_context_from_client_info(client, self.query_ctx.clone());
|
||||
set_client_info(client, &self.session);
|
||||
auth::finish_authentication(client, self.param_provider.as_ref()).await;
|
||||
}
|
||||
}
|
||||
@@ -193,7 +198,7 @@ impl StartupHandler for PostgresServerHandler {
|
||||
)
|
||||
.await;
|
||||
}
|
||||
set_query_context_from_client_info(client, self.query_ctx.clone());
|
||||
set_client_info(client, &self.session);
|
||||
auth::finish_authentication(client, self.param_provider.as_ref()).await;
|
||||
}
|
||||
_ => {}
|
||||
|
||||
@@ -33,7 +33,7 @@ use pgwire::api::stmt::QueryParser;
|
||||
use pgwire::api::store::MemPortalStore;
|
||||
use pgwire::api::{ClientInfo, Type};
|
||||
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
|
||||
use sql::dialect::GenericDialect;
|
||||
use sql::dialect::PostgreSqlDialect;
|
||||
use sql::parser::ParserContext;
|
||||
use sql::statements::statement::Statement;
|
||||
|
||||
@@ -55,13 +55,13 @@ impl SimpleQueryHandler for PostgresServerHandler {
|
||||
),
|
||||
(
|
||||
crate::metrics::METRIC_DB_LABEL,
|
||||
self.query_ctx.get_db_string()
|
||||
self.session.context().get_db_string()
|
||||
)
|
||||
]
|
||||
);
|
||||
let outputs = self
|
||||
.query_handler
|
||||
.do_query(query, self.query_ctx.clone())
|
||||
.do_query(query, self.session.context())
|
||||
.await;
|
||||
|
||||
let mut results = Vec::with_capacity(outputs.len());
|
||||
@@ -260,7 +260,7 @@ impl QueryParser for POCQueryParser {
|
||||
|
||||
fn parse_sql(&self, sql: &str, types: &[Type]) -> PgWireResult<Self::Statement> {
|
||||
increment_counter!(crate::metrics::METRIC_POSTGRES_PREPARED_COUNT);
|
||||
let mut stmts = ParserContext::create_with_dialect(sql, &GenericDialect {})
|
||||
let mut stmts = ParserContext::create_with_dialect(sql, &PostgreSqlDialect {})
|
||||
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
|
||||
if stmts.len() != 1 {
|
||||
Err(PgWireError::UserError(Box::new(ErrorInfo::new(
|
||||
@@ -361,7 +361,7 @@ impl ExtendedQueryHandler for PostgresServerHandler {
|
||||
),
|
||||
(
|
||||
crate::metrics::METRIC_DB_LABEL,
|
||||
self.query_ctx.get_db_string()
|
||||
self.session.context().get_db_string()
|
||||
)
|
||||
]
|
||||
);
|
||||
@@ -376,7 +376,7 @@ impl ExtendedQueryHandler for PostgresServerHandler {
|
||||
|
||||
let output = self
|
||||
.query_handler
|
||||
.do_query(&sql, self.query_ctx.clone())
|
||||
.do_query(&sql, self.session.context())
|
||||
.await
|
||||
.remove(0);
|
||||
|
||||
@@ -407,7 +407,7 @@ impl ExtendedQueryHandler for PostgresServerHandler {
|
||||
|
||||
if let Some(schema) = self
|
||||
.query_handler
|
||||
.do_describe(stmt.clone(), self.query_ctx.clone())
|
||||
.do_describe(stmt.clone(), self.session.context())
|
||||
.await
|
||||
.map_err(|e| PgWireError::ApiError(Box::new(e)))?
|
||||
{
|
||||
|
||||
@@ -73,19 +73,22 @@ impl PostgresServer {
|
||||
accepting_stream.for_each(move |tcp_stream| {
|
||||
let io_runtime = io_runtime.clone();
|
||||
let tls_acceptor = tls_acceptor.clone();
|
||||
let handler = handler.make();
|
||||
|
||||
let mut handler = handler.make();
|
||||
async move {
|
||||
match tcp_stream {
|
||||
Err(error) => error!("Broken pipe: {}", error), // IoError doesn't impl ErrorExt.
|
||||
Ok(io_stream) => {
|
||||
match io_stream.peer_addr() {
|
||||
Ok(addr) => debug!("PostgreSQL client coming from {}", addr),
|
||||
Ok(addr) => {
|
||||
handler.session.mut_conn_info().client_addr = Some(addr);
|
||||
debug!("PostgreSQL client coming from {}", addr)
|
||||
}
|
||||
Err(e) => warn!("Failed to get PostgreSQL client addr, err: {}", e),
|
||||
}
|
||||
|
||||
io_runtime.spawn(async move {
|
||||
increment_gauge!(crate::metrics::METRIC_POSTGRES_CONNECTIONS, 1.0);
|
||||
let handler = Arc::new(handler);
|
||||
let r = process_socket(
|
||||
io_stream,
|
||||
tls_acceptor.clone(),
|
||||
|
||||
Reference in New Issue
Block a user