From e0aecc920925328cff6315efd5fc1921f46617d0 Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Fri, 21 Jul 2023 11:50:32 +0800 Subject: [PATCH] refactor: improve semantics of session and query context (#2009) --- src/servers/src/mysql/handler.rs | 81 ++++++++++++------------ src/servers/src/postgres/auth_handler.rs | 5 +- src/servers/src/postgres/handler.rs | 29 +++------ src/session/src/lib.rs | 33 +++++++--- 4 files changed, 77 insertions(+), 71 deletions(-) diff --git a/src/servers/src/mysql/handler.rs b/src/servers/src/mysql/handler.rs index acfaedd0ee..e39b3bd143 100644 --- a/src/servers/src/mysql/handler.rs +++ b/src/servers/src/mysql/handler.rs @@ -32,7 +32,7 @@ use parking_lot::RwLock; use query::plan::LogicalPlan; use query::query_engine::DescribeResult; use rand::RngCore; -use session::context::Channel; +use session::context::{Channel, QueryContextRef}; use session::{Session, SessionRef}; use snafu::{ensure, ResultExt}; use sql::dialect::MySqlDialect; @@ -89,18 +89,16 @@ impl MysqlInstanceShim { } } - async fn do_query(&self, query: &str) -> Vec> { + async fn do_query(&self, query: &str, query_ctx: QueryContextRef) -> Vec> { trace!("Start executing query: '{}'", query); let start = Instant::now(); - let output = - if let Some(output) = crate::mysql::federated::check(query, self.session.context()) { - vec![Ok(output)] - } else { - self.query_handler - .do_query(query, self.session.context()) - .await - }; + let output = if let Some(output) = crate::mysql::federated::check(query, query_ctx.clone()) + { + vec![Ok(output)] + } else { + self.query_handler.do_query(query, query_ctx).await + }; trace!( "Finished executing query: '{}', total time costs in microseconds: {}", @@ -111,21 +109,26 @@ impl MysqlInstanceShim { } /// Execute the logical plan and return the output - async fn do_exec_plan(&self, query: &str, plan: LogicalPlan) -> Result { - if let Some(output) = crate::mysql::federated::check(query, self.session.context()) { + async fn do_exec_plan( + &self, + query: &str, + plan: LogicalPlan, + query_ctx: QueryContextRef, + ) -> Result { + if let Some(output) = crate::mysql::federated::check(query, query_ctx.clone()) { Ok(output) } else { - self.query_handler - .do_exec_plan(plan, self.session.context()) - .await + self.query_handler.do_exec_plan(plan, query_ctx).await } } /// Describe the statement - async fn do_describe(&self, statement: Statement) -> Result> { - self.query_handler - .do_describe(statement, self.session.context()) - .await + async fn do_describe( + &self, + statement: Statement, + query_ctx: QueryContextRef, + ) -> Result> { + self.query_handler.do_describe(statement, query_ctx).await } /// Save query and logical plan, return the unique id @@ -200,6 +203,7 @@ impl AsyncMysqlShim for MysqlInstanceShi raw_query: &'a str, w: StatementMetaWriter<'a, W>, ) -> Result<()> { + let query_ctx = self.session.new_query_context(); let (query, param_num) = replace_placeholders(raw_query); let statement = validate_query(raw_query).await?; @@ -208,7 +212,9 @@ impl AsyncMysqlShim for MysqlInstanceShi // in the form of "$i", it can't process "?" right now. let statement = transform_placeholders(statement); - let describe_result = self.do_describe(statement.clone()).await?; + let describe_result = self + .do_describe(statement.clone(), query_ctx.clone()) + .await?; let (plan, schema) = if let Some(DescribeResult { logical_plan, schema, @@ -240,10 +246,7 @@ impl AsyncMysqlShim for MysqlInstanceShi w.reply(stmt_id, ¶ms, &[]).await?; increment_counter!( crate::metrics::METRIC_MYSQL_PREPARED_COUNT, - &[( - crate::metrics::METRIC_DB_LABEL, - self.session.context().get_db_string() - )] + &[(crate::metrics::METRIC_DB_LABEL, query_ctx.get_db_string())] ); return Ok(()); } @@ -254,6 +257,7 @@ impl AsyncMysqlShim for MysqlInstanceShi p: ParamParser<'a>, w: QueryResultWriter<'a, W>, ) -> Result<()> { + let query_ctx = self.session.new_query_context(); let _timer = timer!( crate::metrics::METRIC_MYSQL_QUERY_TIMER, &[ @@ -261,10 +265,7 @@ impl AsyncMysqlShim for MysqlInstanceShi crate::metrics::METRIC_MYSQL_SUBPROTOCOL_LABEL, crate::metrics::METRIC_MYSQL_BINQUERY.to_string() ), - ( - crate::metrics::METRIC_DB_LABEL, - self.session.context().get_db_string() - ) + (crate::metrics::METRIC_DB_LABEL, query_ctx.get_db_string()) ] ); let params: Vec = p.into_iter().collect(); @@ -294,20 +295,23 @@ impl AsyncMysqlShim for MysqlInstanceShi } let plan = replace_params_with_values(&plan, param_types, params)?; logging::debug!("Mysql execute prepared plan: {}", plan.display_indent()); - let outputs = vec![self.do_exec_plan(&sql_plan.query, plan).await]; + let outputs = vec![ + self.do_exec_plan(&sql_plan.query, plan, query_ctx.clone()) + .await, + ]; (sql_plan.query, outputs) } None => { let query = replace_params(params, sql_plan.query); logging::debug!("Mysql execute replaced query: {}", query); - let outputs = self.do_query(&query).await; + let outputs = self.do_query(&query, query_ctx.clone()).await; (query, outputs) } }; - writer::write_output(w, &query, self.session.context(), outputs).await?; + writer::write_output(w, &query, query_ctx, outputs).await?; Ok(()) } @@ -325,6 +329,7 @@ impl AsyncMysqlShim for MysqlInstanceShi query: &'a str, writer: QueryResultWriter<'a, W>, ) -> Result<()> { + let query_ctx = self.session.new_query_context(); let _timer = timer!( crate::metrics::METRIC_MYSQL_QUERY_TIMER, &[ @@ -332,14 +337,11 @@ impl AsyncMysqlShim for MysqlInstanceShi crate::metrics::METRIC_MYSQL_SUBPROTOCOL_LABEL, crate::metrics::METRIC_MYSQL_TEXTQUERY.to_string() ), - ( - crate::metrics::METRIC_DB_LABEL, - self.session.context().get_db_string() - ) + (crate::metrics::METRIC_DB_LABEL, query_ctx.get_db_string()) ] ); - let outputs = self.do_query(query).await; - writer::write_output(writer, query, self.session.context(), outputs).await?; + let outputs = self.do_query(query, query_ctx.clone()).await; + writer::write_output(writer, query, query_ctx, outputs).await?; Ok(()) } @@ -377,9 +379,8 @@ impl AsyncMysqlShim for MysqlInstanceShi } } - let context = self.session.context(); - context.set_current_catalog(catalog); - context.set_current_schema(schema); + self.session.set_catalog(catalog.into()); + self.session.set_schema(schema.into()); w.ok().await.map_err(|e| e.into()) } diff --git a/src/servers/src/postgres/auth_handler.rs b/src/servers/src/postgres/auth_handler.rs index 92dcca078f..5aa543c580 100644 --- a/src/servers/src/postgres/auth_handler.rs +++ b/src/servers/src/postgres/auth_handler.rs @@ -118,12 +118,11 @@ fn set_client_info(client: &C, session: &Session) where C: ClientInfo, { - let ctx = session.context(); if let Some(current_catalog) = client.metadata().get(super::METADATA_CATALOG) { - ctx.set_current_catalog(current_catalog); + session.set_catalog(current_catalog.clone()); } if let Some(current_schema) = client.metadata().get(super::METADATA_SCHEMA) { - ctx.set_current_schema(current_schema); + session.set_schema(current_schema.clone()); } if let Some(username) = client.metadata().get(super::METADATA_USER) { session.set_user_info(UserInfo::new(username)); diff --git a/src/servers/src/postgres/handler.rs b/src/servers/src/postgres/handler.rs index 5bdaa63bd7..c1cf5c5a9d 100644 --- a/src/servers/src/postgres/handler.rs +++ b/src/servers/src/postgres/handler.rs @@ -46,6 +46,7 @@ impl SimpleQueryHandler for PostgresServerHandler { where C: ClientInfo + Unpin + Send + Sync, { + let query_ctx = self.session.new_query_context(); let _timer = timer!( crate::metrics::METRIC_POSTGRES_QUERY_TIMER, &[ @@ -53,16 +54,10 @@ impl SimpleQueryHandler for PostgresServerHandler { crate::metrics::METRIC_POSTGRES_SUBPROTOCOL_LABEL, crate::metrics::METRIC_POSTGRES_SIMPLE_QUERY.to_string() ), - ( - crate::metrics::METRIC_DB_LABEL, - self.session.context().get_db_string() - ) + (crate::metrics::METRIC_DB_LABEL, query_ctx.get_db_string()) ] ); - let outputs = self - .query_handler - .do_query(query, self.session.context()) - .await; + let outputs = self.query_handler.do_query(query, query_ctx).await; let mut results = Vec::with_capacity(outputs.len()); @@ -160,6 +155,7 @@ impl QueryParser for DefaultQueryParser { async fn parse_sql(&self, sql: &str, _types: &[Type]) -> PgWireResult { increment_counter!(crate::metrics::METRIC_POSTGRES_PREPARED_COUNT); + let query_ctx = self.session.new_query_context(); let mut stmts = ParserContext::create_with_dialect(sql, &PostgreSqlDialect {}) .map_err(|e| PgWireError::ApiError(Box::new(e)))?; if stmts.len() != 1 { @@ -172,7 +168,7 @@ impl QueryParser for DefaultQueryParser { let stmt = stmts.remove(0); let describe_result = self .query_handler - .do_describe(stmt, self.session.context()) + .do_describe(stmt, query_ctx) .await .map_err(|e| PgWireError::ApiError(Box::new(e)))?; @@ -218,6 +214,7 @@ impl ExtendedQueryHandler for PostgresServerHandler { where C: ClientInfo + Unpin + Send + Sync, { + let query_ctx = self.session.new_query_context(); let _timer = timer!( crate::metrics::METRIC_POSTGRES_QUERY_TIMER, &[ @@ -225,10 +222,7 @@ impl ExtendedQueryHandler for PostgresServerHandler { crate::metrics::METRIC_POSTGRES_SUBPROTOCOL_LABEL, crate::metrics::METRIC_POSTGRES_EXTENDED_QUERY.to_string() ), - ( - crate::metrics::METRIC_DB_LABEL, - self.session.context().get_db_string() - ) + (crate::metrics::METRIC_DB_LABEL, query_ctx.get_db_string()) ] ); let sql_plan = portal.statement().statement(); @@ -237,9 +231,7 @@ impl ExtendedQueryHandler for PostgresServerHandler { let plan = plan .replace_params_with_values(parameters_to_scalar_values(plan, portal)?.as_ref()) .map_err(|e| PgWireError::ApiError(Box::new(e)))?; - self.query_handler - .do_exec_plan(plan, self.session.context()) - .await + self.query_handler.do_exec_plan(plan, query_ctx).await } else { // manually replace variables in prepared statement when no // logical_plan is generated. This happens when logical plan is not @@ -249,10 +241,7 @@ impl ExtendedQueryHandler for PostgresServerHandler { sql = sql.replace(&format!("${}", i + 1), ¶meter_to_string(portal, i)?); } - self.query_handler - .do_query(&sql, self.session.context()) - .await - .remove(0) + self.query_handler.do_query(&sql, query_ctx).await.remove(0) }; output_to_query_response(output, portal.result_column_format()) diff --git a/src/session/src/lib.rs b/src/session/src/lib.rs index 12b663e21a..34ea190e4d 100644 --- a/src/session/src/lib.rs +++ b/src/session/src/lib.rs @@ -18,6 +18,7 @@ use std::net::SocketAddr; use std::sync::Arc; use arc_swap::ArcSwap; +use common_catalog::build_db_string; use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use crate::context::{Channel, ConnInfo, QueryContext, QueryContextRef, UserInfo}; @@ -25,7 +26,8 @@ use crate::context::{Channel, ConnInfo, QueryContext, QueryContextRef, UserInfo} /// Session for persistent connection such as MySQL, PostgreSQL etc. #[derive(Debug)] pub struct Session { - query_ctx: QueryContextRef, + catalog: ArcSwap, + schema: ArcSwap, user_info: ArcSwap, conn_info: ConnInfo, } @@ -35,19 +37,20 @@ pub type SessionRef = Arc; impl Session { pub fn new(addr: Option, channel: Channel) -> Self { Session { - query_ctx: Arc::new(QueryContext::with_sql_dialect( - DEFAULT_CATALOG_NAME, - DEFAULT_SCHEMA_NAME, - channel.dialect(), - )), + catalog: ArcSwap::new(Arc::new(DEFAULT_CATALOG_NAME.into())), + schema: ArcSwap::new(Arc::new(DEFAULT_SCHEMA_NAME.into())), user_info: ArcSwap::new(Arc::new(UserInfo::default())), conn_info: ConnInfo::new(addr, channel), } } #[inline] - pub fn context(&self) -> QueryContextRef { - self.query_ctx.clone() + pub fn new_query_context(&self) -> QueryContextRef { + Arc::new(QueryContext::with_sql_dialect( + self.catalog.load().as_ref(), + self.schema.load().as_ref(), + self.conn_info.channel.dialect(), + )) } #[inline] @@ -69,4 +72,18 @@ impl Session { pub fn set_user_info(&self, user_info: UserInfo) { self.user_info.store(Arc::new(user_info)); } + + #[inline] + pub fn set_catalog(&self, catalog: String) { + self.catalog.store(Arc::new(catalog)); + } + + #[inline] + pub fn set_schema(&self, schema: String) { + self.schema.store(Arc::new(schema)); + } + + pub fn get_db_string(&self) -> String { + build_db_string(self.catalog.load().as_ref(), self.schema.load().as_ref()) + } }