1use std::collections::HashMap;
16use std::net::SocketAddr;
17use std::sync::Arc;
18use std::sync::atomic::{AtomicU32, Ordering};
19use std::time::Duration;
20
21use ::auth::{Identity, Password, UserProviderRef};
22use async_trait::async_trait;
23use chrono::{NaiveDate, NaiveDateTime};
24use common_catalog::parse_optional_catalog_and_schema_from_db_string;
25use common_error::ext::ErrorExt;
26use common_query::Output;
27use common_telemetry::{debug, error, tracing, warn};
28use datafusion_common::ParamValues;
29use datafusion_expr::LogicalPlan;
30use datatypes::prelude::ConcreteDataType;
31use datatypes::schema::Schema;
32use itertools::Itertools;
33use mysql_common::Value as MysqlValue;
34use opensrv_mysql::{
35 AsyncMysqlShim, Column, ErrorKind, InitWriter, ParamParser, ParamValue, QueryResultWriter,
36 StatementMetaWriter, ValueInner,
37};
38use parking_lot::RwLock;
39use query::planner::DfLogicalPlanner;
40use query::query_engine::DescribeResult;
41use rand::RngCore;
42use session::context::{Channel, QueryContextRef};
43use session::{Session, SessionRef};
44use snafu::{ResultExt, ensure};
45use sql::dialect::MySqlDialect;
46use sql::parser::{ParseOptions, ParserContext};
47use sql::statements::statement::Statement;
48use tokio::io::AsyncWrite;
49
50use crate::SqlPlan;
51use crate::error::{
52 self, DataFrameSnafu, InferParameterTypesSnafu, InvalidPrepareStatementSnafu, Result,
53};
54use crate::metrics::METRIC_AUTH_FAILURE;
55use crate::mysql::helper::{self, format_placeholder, transform_placeholders_with_count};
56use crate::mysql::writer;
57use crate::mysql::writer::{create_mysql_column, handle_err};
58use crate::query_handler::sql::ServerSqlQueryHandlerRef;
59
60const MYSQL_NATIVE_PASSWORD: &str = "mysql_native_password";
61const MYSQL_CLEAR_PASSWORD: &str = "mysql_clear_password";
62
63enum Params<'a> {
65 ProtocolParams(Vec<ParamValue<'a>>),
67 CliParams(Vec<sql::ast::Expr>),
69}
70
71impl Params<'_> {
72 fn len(&self) -> usize {
73 match self {
74 Params::ProtocolParams(params) => params.len(),
75 Params::CliParams(params) => params.len(),
76 }
77 }
78}
79
80pub struct MysqlInstanceShim {
82 query_handler: ServerSqlQueryHandlerRef,
83 salt: [u8; 20],
84 session: SessionRef,
85 user_provider: Option<UserProviderRef>,
86 prepared_stmts: Arc<RwLock<HashMap<String, SqlPlan>>>,
87 prepared_stmts_counter: AtomicU32,
88 process_id: u32,
89 prepared_stmt_cache_size: usize,
90}
91
92impl MysqlInstanceShim {
93 pub fn create(
94 query_handler: ServerSqlQueryHandlerRef,
95 user_provider: Option<UserProviderRef>,
96 client_addr: SocketAddr,
97 process_id: u32,
98 prepared_stmt_cache_size: usize,
99 ) -> MysqlInstanceShim {
100 let mut bs = vec![0u8; 20];
102 let mut rng = rand::rng();
103 rng.fill_bytes(bs.as_mut());
104
105 let mut scramble: [u8; 20] = [0; 20];
106 for i in 0..20 {
107 scramble[i] = bs[i] & 0x7fu8;
108 if scramble[i] == b'\0' || scramble[i] == b'$' {
109 scramble[i] += 1;
110 }
111 }
112
113 MysqlInstanceShim {
114 query_handler,
115 salt: scramble,
116 session: Arc::new(Session::new(
117 Some(client_addr),
118 Channel::Mysql,
119 Default::default(),
120 process_id,
121 )),
122 user_provider,
123 prepared_stmts: Default::default(),
124 prepared_stmts_counter: AtomicU32::new(1),
125 process_id,
126 prepared_stmt_cache_size,
127 }
128 }
129
130 #[tracing::instrument(skip_all, name = "mysql::do_query")]
131 async fn do_query(&self, query: &str, query_ctx: QueryContextRef) -> Vec<Result<Output>> {
132 if let Some(output) =
133 crate::mysql::federated::check(query, query_ctx.clone(), self.session.clone())
134 {
135 vec![Ok(output)]
136 } else {
137 self.query_handler.do_query(query, query_ctx.clone()).await
138 }
139 }
140
141 async fn do_describe(
143 &self,
144 statement: Statement,
145 query_ctx: QueryContextRef,
146 ) -> Result<Option<DescribeResult>> {
147 self.query_handler.do_describe(statement, query_ctx).await
148 }
149
150 fn save_plan(&self, plan: SqlPlan, stmt_key: String) -> Result<()> {
152 let mut prepared_stmts = self.prepared_stmts.write();
153 let max_capacity = self.prepared_stmt_cache_size;
154
155 let is_update = prepared_stmts.contains_key(&stmt_key);
156
157 if !is_update && prepared_stmts.len() >= max_capacity {
158 return error::InternalSnafu {
159 err_msg: format!(
160 "Prepared statement cache is full, max capacity: {}",
161 max_capacity
162 ),
163 }
164 .fail();
165 }
166
167 let _ = prepared_stmts.insert(stmt_key, plan);
168 Ok(())
169 }
170
171 fn plan(&self, stmt_key: &str) -> Option<SqlPlan> {
173 let guard = self.prepared_stmts.read();
174 guard.get(stmt_key).cloned()
175 }
176
177 async fn do_prepare(
179 &mut self,
180 raw_query: &str,
181 query_ctx: QueryContextRef,
182 stmt_key: String,
183 ) -> Result<(Vec<Column>, Vec<Column>)> {
184 if crate::mysql::federated::check(raw_query, query_ctx.clone(), self.session.clone())
185 .is_some()
186 {
187 self.save_plan(SqlPlan::Shortcut(raw_query.to_string()), stmt_key)
188 .inspect_err(|e| {
189 error!(e; "Failed to save prepared statement");
190 })?;
191 return Ok((vec![], vec![]));
192 }
193
194 let statement = validate_query(raw_query).await?;
195
196 let (statement, placeholder_count) = transform_placeholders_with_count(statement);
199 let param_num = placeholder_count + 1;
200
201 let describe_result = self
202 .do_describe(statement.clone(), query_ctx.clone())
203 .await?;
204 let plan = describe_result.map(|DescribeResult { logical_plan }| logical_plan);
205
206 let (params, can_cache_as_plan) = if let Some(plan) = &plan {
207 let param_types = DfLogicalPlanner::get_inferred_parameter_types(plan)
208 .context(InferParameterTypesSnafu)?
209 .into_iter()
210 .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
211 .collect();
212
213 (
214 prepared_params(¶m_types, param_num)?,
215 all_params_have_types(¶m_types, param_num),
216 )
217 } else {
218 (dummy_params(param_num)?, false)
219 };
220
221 let columns =
222 plan.as_ref()
223 .map(|plan| {
224 let schema: Schema = plan.schema().clone().try_into().map_err(
225 |e: datatypes::error::Error| {
226 error::InternalSnafu {
227 err_msg: e.to_string(),
228 }
229 .build()
230 },
231 )?;
232 schema
233 .column_schemas()
234 .iter()
235 .map(|column_schema| {
236 create_mysql_column(&column_schema.data_type, &column_schema.name)
237 })
238 .collect::<Result<Vec<_>>>()
239 })
240 .transpose()?
241 .unwrap_or_default();
242
243 match plan {
244 Some(plan) if can_cache_as_plan => {
245 self.save_plan(SqlPlan::Plan(plan, statement), stmt_key)
246 .inspect_err(|e| {
247 error!(e; "Failed to save prepared statement");
248 })?;
249 }
250 _ => {
251 self.save_plan(
252 SqlPlan::Statement(statement, raw_query.to_string()),
253 stmt_key,
254 )
255 .inspect_err(|e| {
256 error!(e; "Failed to save prepared statement");
257 })?;
258 }
259 }
260
261 Ok((params, columns))
262 }
263
264 async fn do_execute(
265 &mut self,
266 query_ctx: QueryContextRef,
267 stmt_key: String,
268 params: Params<'_>,
269 ) -> Result<Vec<std::result::Result<Output, error::Error>>> {
270 let sql_plan = match self.plan(&stmt_key) {
271 None => {
272 return error::PrepareStatementNotFoundSnafu { name: stmt_key }.fail();
273 }
274 Some(sql_plan) => sql_plan,
275 };
276
277 let outputs = match sql_plan {
278 SqlPlan::Plan(plan, stmt) => {
279 let param_types = DfLogicalPlanner::get_inferred_parameter_types(&plan)
280 .context(InferParameterTypesSnafu)?
281 .into_iter()
282 .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
283 .collect::<HashMap<_, _>>();
284
285 if params.len() != param_types.len() {
286 return error::InternalSnafu {
287 err_msg: "Prepare statement params number mismatch".to_string(),
288 }
289 .fail();
290 }
291
292 let replaced_plan = match params {
293 Params::ProtocolParams(params) => {
294 replace_params_with_values(&plan, param_types, ¶ms)
295 }
296 Params::CliParams(params) => {
297 replace_params_with_exprs(&plan, param_types, ¶ms)
298 }
299 }?;
300
301 debug!(
302 "Mysql execute prepared plan: {}",
303 replaced_plan.display_indent()
304 );
305 vec![
306 self.query_handler
307 .do_exec_plan(replaced_plan, Some(stmt), query_ctx.clone())
308 .await,
309 ]
310 }
311 SqlPlan::Shortcut(query) => {
312 if let Some(output) =
313 crate::mysql::federated::check(&query, query_ctx.clone(), self.session.clone())
314 {
315 vec![Ok(output)]
316 } else {
317 self.do_query(&query, query_ctx.clone()).await
318 }
319 }
320 SqlPlan::Statement(stmt, query) => {
321 let param_strs = match params {
322 Params::ProtocolParams(params) => {
323 params.iter().map(convert_param_value_to_string).collect()
324 }
325 Params::CliParams(params) => params.iter().map(|x| x.to_string()).collect(),
326 };
327 debug!(
328 "do_execute Replacing with Params: {:?}, Original Query: {}",
329 param_strs, query
330 );
331 let query = replace_params(param_strs, stmt, query)?;
332 debug!("Mysql execute replaced query: {}", query);
333 self.do_query(&query, query_ctx.clone()).await
334 }
335 _ => {
336 return error::PrepareStatementNotFoundSnafu { name: stmt_key }.fail();
337 }
338 };
339
340 Ok(outputs)
341 }
342
343 fn do_close(&mut self, stmt_key: String) {
345 let mut guard = self.prepared_stmts.write();
346 let _ = guard.remove(&stmt_key);
347 }
348
349 fn auth_plugin(&self) -> &'static str {
350 if self
351 .user_provider
352 .as_ref()
353 .map(|x| x.external())
354 .unwrap_or(false)
355 {
356 MYSQL_CLEAR_PASSWORD
357 } else {
358 MYSQL_NATIVE_PASSWORD
359 }
360 }
361}
362
363#[async_trait]
364impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShim {
365 type Error = error::Error;
366
367 fn version(&self) -> String {
368 std::env::var("GREPTIMEDB_MYSQL_SERVER_VERSION").unwrap_or_else(|_| "8.4.2".to_string())
369 }
370
371 fn connect_id(&self) -> u32 {
372 self.process_id
373 }
374
375 fn default_auth_plugin(&self) -> &str {
376 self.auth_plugin()
377 }
378
379 async fn auth_plugin_for_username(&self, _user: &[u8]) -> &'static str {
380 self.auth_plugin()
381 }
382
383 fn salt(&self) -> [u8; 20] {
384 self.salt
385 }
386
387 async fn authenticate(
388 &self,
389 auth_plugin: &str,
390 username: &[u8],
391 salt: &[u8],
392 auth_data: &[u8],
393 ) -> bool {
394 let username = String::from_utf8_lossy(username);
396
397 let mut user_info = None;
398 let addr = self
399 .session
400 .conn_info()
401 .client_addr
402 .map(|addr| addr.to_string());
403 if let Some(user_provider) = &self.user_provider {
404 let user_id = Identity::UserId(&username, addr.as_deref());
405
406 let password = match auth_plugin {
407 MYSQL_NATIVE_PASSWORD => Password::MysqlNativePassword(auth_data, salt),
408 MYSQL_CLEAR_PASSWORD => {
409 let password = if let &[password @ .., 0] = &auth_data {
412 password
413 } else {
414 auth_data
415 };
416 Password::PlainText(String::from_utf8_lossy(password).to_string().into())
417 }
418 other => {
419 error!("Unsupported mysql auth plugin: {}", other);
420 return false;
421 }
422 };
423 match user_provider.authenticate(user_id, password).await {
424 Ok(userinfo) => {
425 user_info = Some(userinfo);
426 }
427 Err(e) => {
428 METRIC_AUTH_FAILURE
429 .with_label_values(&[e.status_code().as_ref()])
430 .inc();
431 warn!(e; "Failed to auth");
432 return false;
433 }
434 };
435 }
436 let user_info =
437 user_info.unwrap_or_else(|| auth::userinfo_by_name(Some(username.to_string())));
438
439 self.session.set_user_info(user_info);
440
441 true
442 }
443
444 async fn on_prepare<'a>(
445 &'a mut self,
446 raw_query: &'a str,
447 w: StatementMetaWriter<'a, W>,
448 ) -> Result<()> {
449 let query_ctx = self.session.new_query_context();
450 let stmt_id = self.prepared_stmts_counter.fetch_add(1, Ordering::Relaxed);
451 let stmt_key = uuid::Uuid::from_u128(stmt_id as u128).to_string();
452 let (params, columns) = match self
453 .do_prepare(raw_query, query_ctx.clone(), stmt_key)
454 .await
455 {
456 Ok(x) => x,
457 Err(e) => {
458 let (kind, msg) = handle_err(e, query_ctx.clone());
459 w.error(kind, msg.as_bytes()).await?;
460 return Ok(());
461 }
462 };
463 debug!("on_prepare: Params: {:?}, Columns: {:?}", params, columns);
464 w.reply(stmt_id, ¶ms, &columns).await?;
465 crate::metrics::METRIC_MYSQL_PREPARED_COUNT
466 .with_label_values(&[query_ctx.get_db_string().as_str()])
467 .inc();
468 return Ok(());
469 }
470
471 async fn on_execute<'a>(
472 &'a mut self,
473 stmt_id: u32,
474 p: ParamParser<'a>,
475 w: QueryResultWriter<'a, W>,
476 ) -> Result<()> {
477 self.session.clear_warnings();
478
479 let query_ctx = self.session.new_query_context();
480 let db = query_ctx.get_db_string();
481 let _timer = crate::metrics::METRIC_MYSQL_QUERY_TIMER
482 .with_label_values(&[crate::metrics::METRIC_MYSQL_BINQUERY, db.as_str()])
483 .start_timer();
484
485 let params: Vec<ParamValue> = p.into_iter().collect();
486 let stmt_key = uuid::Uuid::from_u128(stmt_id as u128).to_string();
487
488 let outputs = match self
489 .do_execute(query_ctx.clone(), stmt_key, Params::ProtocolParams(params))
490 .await
491 {
492 Ok(outputs) => outputs,
493 Err(e) => {
494 let (kind, err) = handle_err(e, query_ctx);
495 debug!(
496 "Failed to execute prepared statement, kind: {:?}, err: {}",
497 kind, err
498 );
499 w.error(kind, err.as_bytes()).await?;
500 return Ok(());
501 }
502 };
503
504 writer::write_output(w, query_ctx, self.session.clone(), outputs).await?;
505
506 Ok(())
507 }
508
509 async fn on_close<'a>(&'a mut self, stmt_id: u32)
510 where
511 W: 'async_trait,
512 {
513 let stmt_key = uuid::Uuid::from_u128(stmt_id as u128).to_string();
514 self.do_close(stmt_key);
515 }
516
517 #[tracing::instrument(skip_all, fields(protocol = "mysql"))]
518 async fn on_query<'a>(
519 &'a mut self,
520 query: &'a str,
521 writer: QueryResultWriter<'a, W>,
522 ) -> Result<()> {
523 let query_ctx = self.session.new_query_context();
524 let db = query_ctx.get_db_string();
525 let _timer = crate::metrics::METRIC_MYSQL_QUERY_TIMER
526 .with_label_values(&[crate::metrics::METRIC_MYSQL_TEXTQUERY, db.as_str()])
527 .start_timer();
528
529 let query_upcase = query.to_uppercase();
531 if !query_upcase.starts_with("SHOW WARNINGS") {
532 self.session.clear_warnings();
533 }
534
535 if query_upcase.starts_with("PREPARE ") {
536 match ParserContext::parse_mysql_prepare_stmt(query, query_ctx.sql_dialect()) {
537 Ok((stmt_name, stmt)) => {
538 let prepare_results =
539 self.do_prepare(&stmt, query_ctx.clone(), stmt_name).await;
540 match prepare_results {
541 Ok(_) => {
542 let outputs = vec![Ok(Output::new_with_affected_rows(0))];
543 writer::write_output(writer, query_ctx, self.session.clone(), outputs)
544 .await?;
545 return Ok(());
546 }
547 Err(e) => {
548 writer
549 .error(ErrorKind::ER_SP_BADSTATEMENT, e.output_msg().as_bytes())
550 .await?;
551 return Ok(());
552 }
553 }
554 }
555 Err(e) => {
556 writer
557 .error(ErrorKind::ER_PARSE_ERROR, e.output_msg().as_bytes())
558 .await?;
559 return Ok(());
560 }
561 }
562 } else if query_upcase.starts_with("EXECUTE ") {
563 match ParserContext::parse_mysql_execute_stmt(query, query_ctx.sql_dialect()) {
564 Ok((stmt_name, params)) => {
565 let outputs = match self
566 .do_execute(query_ctx.clone(), stmt_name, Params::CliParams(params))
567 .await
568 {
569 Ok(outputs) => outputs,
570 Err(e) => {
571 let (kind, err) = handle_err(e, query_ctx);
572 debug!(
573 "Failed to execute prepared statement, kind: {:?}, err: {}",
574 kind, err
575 );
576 writer.error(kind, err.as_bytes()).await?;
577 return Ok(());
578 }
579 };
580 writer::write_output(writer, query_ctx, self.session.clone(), outputs).await?;
581
582 return Ok(());
583 }
584 Err(e) => {
585 writer
586 .error(ErrorKind::ER_PARSE_ERROR, e.output_msg().as_bytes())
587 .await?;
588 return Ok(());
589 }
590 }
591 } else if query_upcase.starts_with("DEALLOCATE ") {
592 match ParserContext::parse_mysql_deallocate_stmt(query, query_ctx.sql_dialect()) {
593 Ok(stmt_name) => {
594 self.do_close(stmt_name);
595 let outputs = vec![Ok(Output::new_with_affected_rows(0))];
596 writer::write_output(writer, query_ctx, self.session.clone(), outputs).await?;
597 return Ok(());
598 }
599 Err(e) => {
600 writer
601 .error(ErrorKind::ER_PARSE_ERROR, e.output_msg().as_bytes())
602 .await?;
603 return Ok(());
604 }
605 }
606 }
607
608 let outputs = self.do_query(query, query_ctx.clone()).await;
609 writer::write_output(writer, query_ctx, self.session.clone(), outputs).await?;
610
611 Ok(())
612 }
613
614 async fn on_init<'a>(&'a mut self, database: &'a str, w: InitWriter<'a, W>) -> Result<()> {
615 let (catalog_from_db, schema) = parse_optional_catalog_and_schema_from_db_string(database);
616 let catalog = if let Some(catalog) = &catalog_from_db {
617 catalog.clone()
618 } else {
619 self.session.catalog()
620 };
621
622 if !self
623 .query_handler
624 .is_valid_schema(&catalog, &schema)
625 .await?
626 {
627 return w
628 .error(
629 ErrorKind::ER_WRONG_DB_NAME,
630 format!("Unknown database '{}'", database).as_bytes(),
631 )
632 .await
633 .map_err(|e| e.into());
634 }
635
636 let user_info = &self.session.user_info();
637
638 if let Some(schema_validator) = &self.user_provider
639 && let Err(e) = schema_validator
640 .authorize(&catalog, &schema, user_info)
641 .await
642 {
643 METRIC_AUTH_FAILURE
644 .with_label_values(&[e.status_code().as_ref()])
645 .inc();
646 return w
647 .error(
648 ErrorKind::ER_DBACCESS_DENIED_ERROR,
649 e.output_msg().as_bytes(),
650 )
651 .await
652 .map_err(|e| e.into());
653 }
654
655 if catalog_from_db.is_some() {
656 self.session.set_catalog(catalog)
657 }
658 self.session.set_schema(schema);
659
660 w.ok().await.map_err(|e| e.into())
661 }
662}
663
664fn convert_param_value_to_string(param: &ParamValue) -> String {
665 match param.value.into_inner() {
666 ValueInner::Int(u) => u.to_string(),
667 ValueInner::UInt(u) => u.to_string(),
668 ValueInner::Double(u) => u.to_string(),
669 ValueInner::NULL => "NULL".to_string(),
670 ValueInner::Bytes(b) => MysqlValue::Bytes(b.to_vec()).as_sql(false),
675 ValueInner::Date(_) => format!("'{}'", NaiveDate::from(param.value)),
676 ValueInner::Datetime(_) => format!("'{}'", NaiveDateTime::from(param.value)),
677 ValueInner::Time(_) => format_duration(Duration::from(param.value)),
678 }
679}
680
681fn replace_params(params: Vec<String>, stmt: Statement, mut query: String) -> Result<String> {
682 let spans = helper::placeholder_spans(stmt);
683 ensure!(
684 spans.len() == params.len(),
685 error::InternalSnafu {
686 err_msg: format!(
687 "Prepared statement expected {} parameters but got {}",
688 spans.len(),
689 params.len()
690 )
691 }
692 );
693
694 let mut replacements = Vec::with_capacity(spans.len());
695 for span in spans {
696 let start = location_to_byte_offset(&query, span.start_line, span.start_column)
697 .ok_or_else(|| {
698 error::InternalSnafu {
699 err_msg: format!(
700 "Invalid placeholder start span: line {}, column {}",
701 span.start_line, span.start_column
702 ),
703 }
704 .build()
705 })?;
706 let end =
707 location_to_byte_offset(&query, span.end_line, span.end_column).ok_or_else(|| {
708 error::InternalSnafu {
709 err_msg: format!(
710 "Invalid placeholder end span: line {}, column {}",
711 span.end_line, span.end_column
712 ),
713 }
714 .build()
715 })?;
716 let param = span
717 .index
718 .checked_sub(1)
719 .and_then(|idx| params.get(idx))
720 .ok_or_else(|| {
721 error::InternalSnafu {
722 err_msg: format!("Missing prepared statement parameter {}", span.index),
723 }
724 .build()
725 })?;
726
727 ensure!(
728 start < end && end <= query.len(),
729 error::InternalSnafu {
730 err_msg: format!(
731 "Invalid placeholder byte span: {}..{} for query length {}",
732 start,
733 end,
734 query.len()
735 )
736 }
737 );
738 ensure!(
739 query.get(start..end) == Some("?"),
740 error::InternalSnafu {
741 err_msg: format!(
742 "Prepared statement placeholder span maps to {:?} instead of '?'",
743 query.get(start..end)
744 )
745 }
746 );
747
748 replacements.push((start, end, param.clone()));
749 }
750
751 replacements.sort_unstable_by_key(|(start, _, _)| *start);
752 for windows in replacements.windows(2) {
753 ensure!(
754 windows[0].1 <= windows[1].0,
755 error::InternalSnafu {
756 err_msg: "Overlapping placeholder spans in prepared statement".to_string()
757 }
758 );
759 }
760
761 for (start, end, param) in replacements.into_iter().rev() {
765 query.replace_range(start..end, ¶m);
766 }
767
768 Ok(query)
769}
770
771fn location_to_byte_offset(query: &str, line: u64, column: u64) -> Option<usize> {
772 if line == 0 || column == 0 {
776 return None;
777 }
778
779 let mut current_line = 1;
780 let mut current_column = 1;
781 for (index, ch) in query.char_indices() {
782 if current_line == line && current_column == column {
783 return Some(index);
784 }
785
786 if ch == '\n' {
787 current_line += 1;
788 current_column = 1;
789 } else {
790 current_column += 1;
791 }
792 }
793
794 (current_line == line && current_column == column).then_some(query.len())
797}
798
/// Formats a MySQL TIME parameter as a quoted `'HH:MM:SS[.ffffff]'` literal.
///
/// Hours are not wrapped at 24, so durations of a day or more keep their total hour count.
/// Fractional seconds are emitted with microsecond precision when present; previously they
/// were silently dropped. The sign of a negative MySQL TIME is not visible in a `Duration`
/// and therefore cannot be rendered here.
fn format_duration(duration: Duration) -> String {
    let total_secs = duration.as_secs();
    let hours = total_secs / 3600;
    let minutes = (total_secs / 60) % 60;
    let seconds = total_secs % 60;
    let micros = duration.subsec_micros();
    if micros == 0 {
        format!("'{:02}:{:02}:{:02}'", hours, minutes, seconds)
    } else {
        format!("'{:02}:{:02}:{:02}.{:06}'", hours, minutes, seconds, micros)
    }
}
805
806fn replace_params_with_values(
807 plan: &LogicalPlan,
808 param_types: HashMap<String, Option<ConcreteDataType>>,
809 params: &[ParamValue],
810) -> Result<LogicalPlan> {
811 debug_assert_eq!(param_types.len(), params.len());
812
813 debug!(
814 "replace_params_with_values(param_types: {:#?}, params: {:#?}, plan: {:#?})",
815 param_types,
816 params
817 .iter()
818 .map(|x| format!("({:?}, {:?})", x.value, x.coltype))
819 .join(", "),
820 plan
821 );
822
823 let mut values = Vec::with_capacity(params.len());
824
825 for (i, param) in params.iter().enumerate() {
826 if let Some(Some(t)) = param_types.get(&format_placeholder(i + 1)) {
827 let value = helper::convert_value(param, t)?;
828
829 values.push(value.into());
830 }
831 }
832
833 plan.clone()
834 .replace_params_with_values(&ParamValues::List(values.clone()))
835 .context(DataFrameSnafu)
836}
837
838fn replace_params_with_exprs(
839 plan: &LogicalPlan,
840 param_types: HashMap<String, Option<ConcreteDataType>>,
841 params: &[sql::ast::Expr],
842) -> Result<LogicalPlan> {
843 debug_assert_eq!(param_types.len(), params.len());
844
845 debug!(
846 "replace_params_with_exprs(param_types: {:#?}, params: {:#?}, plan: {:#?})",
847 param_types,
848 params.iter().map(|x| format!("({:?})", x)).join(", "),
849 plan
850 );
851
852 let mut values = Vec::with_capacity(params.len());
853
854 for (i, param) in params.iter().enumerate() {
855 if let Some(Some(t)) = param_types.get(&format_placeholder(i + 1)) {
856 let value = helper::convert_expr_to_scalar_value(param, t)?;
857
858 values.push(value.into());
859 }
860 }
861
862 plan.clone()
863 .replace_params_with_values(&ParamValues::List(values.clone()))
864 .context(DataFrameSnafu)
865}
866
867async fn validate_query(query: &str) -> Result<Statement> {
868 let statement =
869 ParserContext::create_with_dialect(query, &MySqlDialect {}, ParseOptions::default());
870 let mut statement = statement.map_err(|e| {
871 InvalidPrepareStatementSnafu {
872 err_msg: e.output_msg(),
873 }
874 .build()
875 })?;
876
877 ensure!(
878 statement.len() == 1,
879 InvalidPrepareStatementSnafu {
880 err_msg: "prepare statement only support single statement".to_string(),
881 }
882 );
883
884 let statement = statement.remove(0);
885
886 Ok(statement)
887}
888
889fn dummy_params(index: usize) -> Result<Vec<Column>> {
890 let mut params = Vec::with_capacity(index - 1);
891
892 for _ in 1..index {
893 params.push(create_mysql_column(&ConcreteDataType::null_datatype(), "")?);
894 }
895
896 Ok(params)
897}
898
899fn prepared_params(
901 param_types: &HashMap<String, Option<ConcreteDataType>>,
902 param_num: usize,
903) -> Result<Vec<Column>> {
904 let mut params = Vec::with_capacity(param_num - 1);
905
906 for i in 1..param_num {
908 let column = if let Some(Some(t)) = param_types.get(&format_placeholder(i)) {
909 create_mysql_column(t, "")?
910 } else {
911 create_mysql_column(&ConcreteDataType::null_datatype(), "")?
912 };
913 params.push(column);
914 }
915
916 Ok(params)
917}
918
919fn all_params_have_types(
920 param_types: &HashMap<String, Option<ConcreteDataType>>,
921 param_num: usize,
922) -> bool {
923 param_types.len() == param_num - 1
924 && (1..param_num).all(|i| matches!(param_types.get(&format_placeholder(i)), Some(Some(_))))
925}
926
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use async_trait::async_trait;
    use common_query::Output;
    use datafusion_expr::LogicalPlan;
    use query::parser::PromQuery;
    use query::query_engine::DescribeResult;
    use session::context::QueryContext;
    use sql::statements::statement::Statement;

    use super::*;
    use crate::error::Result;
    use crate::query_handler::sql::SqlQueryHandler;

    /// Query handler for tests that must not reach the query engine: every method except
    /// `is_valid_schema` (which always answers `true`) panics.
    struct DummyQueryHandler;

    #[async_trait]
    impl SqlQueryHandler for DummyQueryHandler {
        async fn do_query(&self, _: &str, _: QueryContextRef) -> Vec<Result<Output>> {
            unimplemented!()
        }

        async fn do_promql_query(&self, _: &PromQuery, _: QueryContextRef) -> Vec<Result<Output>> {
            unimplemented!()
        }

        async fn do_exec_plan(
            &self,
            _: LogicalPlan,
            _: Option<Statement>,
            _: QueryContextRef,
        ) -> Result<Output> {
            unimplemented!()
        }

        async fn do_describe(
            &self,
            _: Statement,
            _: QueryContextRef,
        ) -> Result<Option<DescribeResult>> {
            unimplemented!()
        }

        async fn is_valid_schema(&self, _: &str, _: &str) -> Result<bool> {
            Ok(true)
        }
    }

    /// Builds a shim backed by `DummyQueryHandler` with no user provider.
    fn create_shim() -> MysqlInstanceShim {
        MysqlInstanceShim::create(
            Arc::new(DummyQueryHandler),
            None,
            "127.0.0.1:3306".parse().unwrap(),
            1,
            1024,
        )
    }

    /// Parses a single MySQL statement and numbers its `?` placeholders, as `do_prepare`
    /// does.
    fn statement_with_transformed_placeholders(query: &str) -> Statement {
        let mut statements =
            ParserContext::create_with_dialect(query, &MySqlDialect {}, ParseOptions::default())
                .unwrap();
        assert_eq!(statements.len(), 1);
        transform_placeholders_with_count(statements.remove(0)).0
    }

    #[test]
    fn test_prepared_params_keep_unknown_type_placeholders() {
        let mut param_types = HashMap::new();
        param_types.insert(format_placeholder(1), None);
        param_types.insert(
            format_placeholder(2),
            Some(ConcreteDataType::int32_datatype()),
        );

        let params = prepared_params(&param_types, 3).unwrap();
        assert_eq!(params.len(), 2);
        assert!(!all_params_have_types(&param_types, 3));
    }

    #[test]
    fn test_replace_params_by_placeholder_span() {
        let query = "SELECT ?, ?".to_string();
        let stmt = statement_with_transformed_placeholders(&query);
        let params = vec!["'$2 should stay'".to_string(), "'value'".to_string()];

        assert_eq!(
            "SELECT '$2 should stay', 'value'",
            replace_params(params, stmt, query).unwrap()
        );

        let query = "SELECT ?, ?, ?".to_string();
        let stmt = statement_with_transformed_placeholders(&query);
        let params = vec![
            "'much longer than a placeholder'".to_string(),
            "0".to_string(),
            "'also much longer than a placeholder'".to_string(),
        ];

        assert_eq!(
            "SELECT 'much longer than a placeholder', 0, 'also much longer than a placeholder'",
            replace_params(params, stmt, query).unwrap()
        );

        let query = "SELECT '$1', \"$2\", `$3`, ?, ?".to_string();
        let stmt = statement_with_transformed_placeholders(&query);
        let params = vec!["'1'".to_string(), "'2'".to_string()];

        assert_eq!(
            "SELECT '$1', \"$2\", `$3`, '1', '2'",
            replace_params(params, stmt, query).unwrap()
        );

        let query = "SELECT /* ? */ ? -- ?\n, ?".to_string();
        let stmt = statement_with_transformed_placeholders(&query);
        let params = vec!["'first'".to_string(), "'second'".to_string()];

        assert_eq!(
            "SELECT /* ? */ 'first' -- ?\n, 'second'",
            replace_params(params, stmt, query).unwrap()
        );

        let query = "SELECT '中文', ?".to_string();
        let stmt = statement_with_transformed_placeholders(&query);
        let params = vec!["'value'".to_string()];

        assert_eq!(
            "SELECT '中文', 'value'",
            replace_params(params, stmt, query).unwrap()
        );

        let query = "SELECT '中文',\n ?".to_string();
        let stmt = statement_with_transformed_placeholders(&query);
        let params = vec!["'value'".to_string()];

        assert_eq!(
            "SELECT '中文',\n 'value'",
            replace_params(params, stmt, query).unwrap()
        );

        let query = "SELECT 'x'\r\n, ?".to_string();
        let stmt = statement_with_transformed_placeholders(&query);
        let params = vec!["'crlf'".to_string()];

        assert_eq!(
            "SELECT 'x'\r\n, 'crlf'",
            replace_params(params, stmt, query).unwrap()
        );

        let query = "SELECT\t?".to_string();
        let stmt = statement_with_transformed_placeholders(&query);
        let params = vec!["NULL".to_string()];

        assert_eq!("SELECT\tNULL", replace_params(params, stmt, query).unwrap());

        let query = "SELECT CAST(? AS INT64), ? + (SELECT ?)".to_string();
        let stmt = statement_with_transformed_placeholders(&query);
        let params = vec!["1".to_string(), "2".to_string(), "3".to_string()];

        assert_eq!(
            "SELECT CAST(1 AS INT64), 2 + (SELECT 3)",
            replace_params(params, stmt, query).unwrap()
        );

        let query = "SET time_zone = ?".to_string();
        let stmt = statement_with_transformed_placeholders(&query);
        let params = vec!["'UTC'".to_string()];

        assert_eq!(
            "SET time_zone = 'UTC'",
            replace_params(params, stmt, query).unwrap()
        );
    }

    #[tokio::test]
    async fn test_prepare_federated_query() {
        let mut shim = create_shim();
        let query_ctx = QueryContext::arc();
        let stmt_key = "test_federated".to_string();

        let (params, columns) = shim
            .do_prepare(
                "SELECT @@version_comment",
                query_ctx.clone(),
                stmt_key.clone(),
            )
            .await
            .unwrap();

        assert!(params.is_empty());
        assert!(columns.is_empty());

        let plan = shim.plan(&stmt_key).unwrap();
        assert!(matches!(plan, SqlPlan::Shortcut(q) if q == "SELECT @@version_comment"));
    }

    #[tokio::test]
    async fn test_execute_federated_shortcut() {
        let mut shim = create_shim();
        let query_ctx = QueryContext::arc();
        let stmt_key = "test_federated_exec".to_string();

        shim.do_prepare(
            "SELECT @@version_comment",
            query_ctx.clone(),
            stmt_key.clone(),
        )
        .await
        .unwrap();

        let outputs = shim
            .do_execute(query_ctx.clone(), stmt_key, Params::CliParams(vec![]))
            .await
            .unwrap();

        assert_eq!(outputs.len(), 1);
        let output = outputs.into_iter().next().unwrap().unwrap();
        let pretty = output.data.pretty_print().await;
        assert!(pretty.contains("GreptimeDB"));
    }

    /// `SET NAMES` is answered by the federated shortcut, so preparing it stores a
    /// `SqlPlan::Shortcut` without consulting the query handler.
    #[tokio::test]
    async fn test_prepare_set_names_as_shortcut() {
        let mut shim = create_shim();
        let query_ctx = QueryContext::arc();
        let stmt_key = "test_non_federated".to_string();

        let result = shim
            .do_prepare("SET NAMES utf8", query_ctx.clone(), stmt_key.clone())
            .await;

        assert!(result.is_ok());
        let plan = shim.plan(&stmt_key).unwrap();
        assert!(matches!(plan, SqlPlan::Shortcut(_)));
    }

    #[tokio::test]
    async fn test_execute_set_shortcut() {
        let mut shim = create_shim();
        let query_ctx = QueryContext::arc();
        let stmt_key = "test_set_shortcut".to_string();

        shim.do_prepare("SET NAMES utf8", query_ctx.clone(), stmt_key.clone())
            .await
            .unwrap();

        let outputs = shim
            .do_execute(query_ctx.clone(), stmt_key, Params::CliParams(vec![]))
            .await
            .unwrap();

        assert_eq!(outputs.len(), 1);
        let output = outputs.into_iter().next().unwrap().unwrap();
        match output.data {
            common_query::OutputData::RecordBatches(batches) => {
                let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
                assert_eq!(total_rows, 0);
            }
            other => panic!("Expected RecordBatches, got {:?}", other),
        }
    }
}