refactor: improve semantics of session and query context (#2009)

This commit is contained in:
Ning Sun
2023-07-21 11:50:32 +08:00
committed by GitHub
parent a7557b70f1
commit e0aecc9209
4 changed files with 77 additions and 71 deletions

View File

@@ -32,7 +32,7 @@ use parking_lot::RwLock;
use query::plan::LogicalPlan;
use query::query_engine::DescribeResult;
use rand::RngCore;
use session::context::Channel;
use session::context::{Channel, QueryContextRef};
use session::{Session, SessionRef};
use snafu::{ensure, ResultExt};
use sql::dialect::MySqlDialect;
@@ -89,18 +89,16 @@ impl MysqlInstanceShim {
}
}
async fn do_query(&self, query: &str) -> Vec<Result<Output>> {
async fn do_query(&self, query: &str, query_ctx: QueryContextRef) -> Vec<Result<Output>> {
trace!("Start executing query: '{}'", query);
let start = Instant::now();
let output =
if let Some(output) = crate::mysql::federated::check(query, self.session.context()) {
vec![Ok(output)]
} else {
self.query_handler
.do_query(query, self.session.context())
.await
};
let output = if let Some(output) = crate::mysql::federated::check(query, query_ctx.clone())
{
vec![Ok(output)]
} else {
self.query_handler.do_query(query, query_ctx).await
};
trace!(
"Finished executing query: '{}', total time costs in microseconds: {}",
@@ -111,21 +109,26 @@ impl MysqlInstanceShim {
}
/// Execute the logical plan and return the output
async fn do_exec_plan(&self, query: &str, plan: LogicalPlan) -> Result<Output> {
if let Some(output) = crate::mysql::federated::check(query, self.session.context()) {
async fn do_exec_plan(
&self,
query: &str,
plan: LogicalPlan,
query_ctx: QueryContextRef,
) -> Result<Output> {
if let Some(output) = crate::mysql::federated::check(query, query_ctx.clone()) {
Ok(output)
} else {
self.query_handler
.do_exec_plan(plan, self.session.context())
.await
self.query_handler.do_exec_plan(plan, query_ctx).await
}
}
/// Describe the statement
async fn do_describe(&self, statement: Statement) -> Result<Option<DescribeResult>> {
self.query_handler
.do_describe(statement, self.session.context())
.await
async fn do_describe(
&self,
statement: Statement,
query_ctx: QueryContextRef,
) -> Result<Option<DescribeResult>> {
self.query_handler.do_describe(statement, query_ctx).await
}
/// Save query and logical plan, return the unique id
@@ -200,6 +203,7 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
raw_query: &'a str,
w: StatementMetaWriter<'a, W>,
) -> Result<()> {
let query_ctx = self.session.new_query_context();
let (query, param_num) = replace_placeholders(raw_query);
let statement = validate_query(raw_query).await?;
@@ -208,7 +212,9 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
// in the form of "$i", it can't process "?" right now.
let statement = transform_placeholders(statement);
let describe_result = self.do_describe(statement.clone()).await?;
let describe_result = self
.do_describe(statement.clone(), query_ctx.clone())
.await?;
let (plan, schema) = if let Some(DescribeResult {
logical_plan,
schema,
@@ -240,10 +246,7 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
w.reply(stmt_id, &params, &[]).await?;
increment_counter!(
crate::metrics::METRIC_MYSQL_PREPARED_COUNT,
&[(
crate::metrics::METRIC_DB_LABEL,
self.session.context().get_db_string()
)]
&[(crate::metrics::METRIC_DB_LABEL, query_ctx.get_db_string())]
);
return Ok(());
}
@@ -254,6 +257,7 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
p: ParamParser<'a>,
w: QueryResultWriter<'a, W>,
) -> Result<()> {
let query_ctx = self.session.new_query_context();
let _timer = timer!(
crate::metrics::METRIC_MYSQL_QUERY_TIMER,
&[
@@ -261,10 +265,7 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
crate::metrics::METRIC_MYSQL_SUBPROTOCOL_LABEL,
crate::metrics::METRIC_MYSQL_BINQUERY.to_string()
),
(
crate::metrics::METRIC_DB_LABEL,
self.session.context().get_db_string()
)
(crate::metrics::METRIC_DB_LABEL, query_ctx.get_db_string())
]
);
let params: Vec<ParamValue> = p.into_iter().collect();
@@ -294,20 +295,23 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
}
let plan = replace_params_with_values(&plan, param_types, params)?;
logging::debug!("Mysql execute prepared plan: {}", plan.display_indent());
let outputs = vec![self.do_exec_plan(&sql_plan.query, plan).await];
let outputs = vec![
self.do_exec_plan(&sql_plan.query, plan, query_ctx.clone())
.await,
];
(sql_plan.query, outputs)
}
None => {
let query = replace_params(params, sql_plan.query);
logging::debug!("Mysql execute replaced query: {}", query);
let outputs = self.do_query(&query).await;
let outputs = self.do_query(&query, query_ctx.clone()).await;
(query, outputs)
}
};
writer::write_output(w, &query, self.session.context(), outputs).await?;
writer::write_output(w, &query, query_ctx, outputs).await?;
Ok(())
}
@@ -325,6 +329,7 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
query: &'a str,
writer: QueryResultWriter<'a, W>,
) -> Result<()> {
let query_ctx = self.session.new_query_context();
let _timer = timer!(
crate::metrics::METRIC_MYSQL_QUERY_TIMER,
&[
@@ -332,14 +337,11 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
crate::metrics::METRIC_MYSQL_SUBPROTOCOL_LABEL,
crate::metrics::METRIC_MYSQL_TEXTQUERY.to_string()
),
(
crate::metrics::METRIC_DB_LABEL,
self.session.context().get_db_string()
)
(crate::metrics::METRIC_DB_LABEL, query_ctx.get_db_string())
]
);
let outputs = self.do_query(query).await;
writer::write_output(writer, query, self.session.context(), outputs).await?;
let outputs = self.do_query(query, query_ctx.clone()).await;
writer::write_output(writer, query, query_ctx, outputs).await?;
Ok(())
}
@@ -377,9 +379,8 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
}
}
let context = self.session.context();
context.set_current_catalog(catalog);
context.set_current_schema(schema);
self.session.set_catalog(catalog.into());
self.session.set_schema(schema.into());
w.ok().await.map_err(|e| e.into())
}

View File

@@ -118,12 +118,11 @@ fn set_client_info<C>(client: &C, session: &Session)
where
C: ClientInfo,
{
let ctx = session.context();
if let Some(current_catalog) = client.metadata().get(super::METADATA_CATALOG) {
ctx.set_current_catalog(current_catalog);
session.set_catalog(current_catalog.clone());
}
if let Some(current_schema) = client.metadata().get(super::METADATA_SCHEMA) {
ctx.set_current_schema(current_schema);
session.set_schema(current_schema.clone());
}
if let Some(username) = client.metadata().get(super::METADATA_USER) {
session.set_user_info(UserInfo::new(username));

View File

@@ -46,6 +46,7 @@ impl SimpleQueryHandler for PostgresServerHandler {
where
C: ClientInfo + Unpin + Send + Sync,
{
let query_ctx = self.session.new_query_context();
let _timer = timer!(
crate::metrics::METRIC_POSTGRES_QUERY_TIMER,
&[
@@ -53,16 +54,10 @@ impl SimpleQueryHandler for PostgresServerHandler {
crate::metrics::METRIC_POSTGRES_SUBPROTOCOL_LABEL,
crate::metrics::METRIC_POSTGRES_SIMPLE_QUERY.to_string()
),
(
crate::metrics::METRIC_DB_LABEL,
self.session.context().get_db_string()
)
(crate::metrics::METRIC_DB_LABEL, query_ctx.get_db_string())
]
);
let outputs = self
.query_handler
.do_query(query, self.session.context())
.await;
let outputs = self.query_handler.do_query(query, query_ctx).await;
let mut results = Vec::with_capacity(outputs.len());
@@ -160,6 +155,7 @@ impl QueryParser for DefaultQueryParser {
async fn parse_sql(&self, sql: &str, _types: &[Type]) -> PgWireResult<Self::Statement> {
increment_counter!(crate::metrics::METRIC_POSTGRES_PREPARED_COUNT);
let query_ctx = self.session.new_query_context();
let mut stmts = ParserContext::create_with_dialect(sql, &PostgreSqlDialect {})
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
if stmts.len() != 1 {
@@ -172,7 +168,7 @@ impl QueryParser for DefaultQueryParser {
let stmt = stmts.remove(0);
let describe_result = self
.query_handler
.do_describe(stmt, self.session.context())
.do_describe(stmt, query_ctx)
.await
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
@@ -218,6 +214,7 @@ impl ExtendedQueryHandler for PostgresServerHandler {
where
C: ClientInfo + Unpin + Send + Sync,
{
let query_ctx = self.session.new_query_context();
let _timer = timer!(
crate::metrics::METRIC_POSTGRES_QUERY_TIMER,
&[
@@ -225,10 +222,7 @@ impl ExtendedQueryHandler for PostgresServerHandler {
crate::metrics::METRIC_POSTGRES_SUBPROTOCOL_LABEL,
crate::metrics::METRIC_POSTGRES_EXTENDED_QUERY.to_string()
),
(
crate::metrics::METRIC_DB_LABEL,
self.session.context().get_db_string()
)
(crate::metrics::METRIC_DB_LABEL, query_ctx.get_db_string())
]
);
let sql_plan = portal.statement().statement();
@@ -237,9 +231,7 @@ impl ExtendedQueryHandler for PostgresServerHandler {
let plan = plan
.replace_params_with_values(parameters_to_scalar_values(plan, portal)?.as_ref())
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
self.query_handler
.do_exec_plan(plan, self.session.context())
.await
self.query_handler.do_exec_plan(plan, query_ctx).await
} else {
// manually replace variables in prepared statement when no
// logical_plan is generated. This happens when logical plan is not
@@ -249,10 +241,7 @@ impl ExtendedQueryHandler for PostgresServerHandler {
sql = sql.replace(&format!("${}", i + 1), &parameter_to_string(portal, i)?);
}
self.query_handler
.do_query(&sql, self.session.context())
.await
.remove(0)
self.query_handler.do_query(&sql, query_ctx).await.remove(0)
};
output_to_query_response(output, portal.result_column_format())

View File

@@ -18,6 +18,7 @@ use std::net::SocketAddr;
use std::sync::Arc;
use arc_swap::ArcSwap;
use common_catalog::build_db_string;
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use crate::context::{Channel, ConnInfo, QueryContext, QueryContextRef, UserInfo};
@@ -25,7 +26,8 @@ use crate::context::{Channel, ConnInfo, QueryContext, QueryContextRef, UserInfo}
/// Session for persistent connection such as MySQL, PostgreSQL etc.
#[derive(Debug)]
pub struct Session {
query_ctx: QueryContextRef,
catalog: ArcSwap<String>,
schema: ArcSwap<String>,
user_info: ArcSwap<UserInfo>,
conn_info: ConnInfo,
}
@@ -35,19 +37,20 @@ pub type SessionRef = Arc<Session>;
impl Session {
pub fn new(addr: Option<SocketAddr>, channel: Channel) -> Self {
Session {
query_ctx: Arc::new(QueryContext::with_sql_dialect(
DEFAULT_CATALOG_NAME,
DEFAULT_SCHEMA_NAME,
channel.dialect(),
)),
catalog: ArcSwap::new(Arc::new(DEFAULT_CATALOG_NAME.into())),
schema: ArcSwap::new(Arc::new(DEFAULT_SCHEMA_NAME.into())),
user_info: ArcSwap::new(Arc::new(UserInfo::default())),
conn_info: ConnInfo::new(addr, channel),
}
}
#[inline]
pub fn context(&self) -> QueryContextRef {
self.query_ctx.clone()
pub fn new_query_context(&self) -> QueryContextRef {
Arc::new(QueryContext::with_sql_dialect(
self.catalog.load().as_ref(),
self.schema.load().as_ref(),
self.conn_info.channel.dialect(),
))
}
#[inline]
@@ -69,4 +72,18 @@ impl Session {
pub fn set_user_info(&self, user_info: UserInfo) {
self.user_info.store(Arc::new(user_info));
}
#[inline]
pub fn set_catalog(&self, catalog: String) {
self.catalog.store(Arc::new(catalog));
}
#[inline]
pub fn set_schema(&self, schema: String) {
self.schema.store(Arc::new(schema));
}
pub fn get_db_string(&self) -> String {
build_db_string(self.catalog.load().as_ref(), self.schema.load().as_ref())
}
}