Skip to main content

servers/mysql/
handler.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::net::SocketAddr;
17use std::sync::Arc;
18use std::sync::atomic::{AtomicU32, Ordering};
19use std::time::Duration;
20
21use ::auth::{Identity, Password, UserProviderRef};
22use async_trait::async_trait;
23use chrono::{NaiveDate, NaiveDateTime};
24use common_catalog::parse_optional_catalog_and_schema_from_db_string;
25use common_error::ext::ErrorExt;
26use common_query::Output;
27use common_telemetry::{debug, error, tracing, warn};
28use datafusion_common::ParamValues;
29use datafusion_expr::LogicalPlan;
30use datatypes::prelude::ConcreteDataType;
31use itertools::Itertools;
32use opensrv_mysql::{
33    AsyncMysqlShim, Column, ErrorKind, InitWriter, ParamParser, ParamValue, QueryResultWriter,
34    StatementMetaWriter, ValueInner,
35};
36use parking_lot::RwLock;
37use query::planner::DfLogicalPlanner;
38use query::query_engine::DescribeResult;
39use rand::RngCore;
40use session::context::{Channel, QueryContextRef};
41use session::{Session, SessionRef};
42use snafu::{ResultExt, ensure};
43use sql::dialect::MySqlDialect;
44use sql::parser::{ParseOptions, ParserContext};
45use sql::statements::statement::Statement;
46use tokio::io::AsyncWrite;
47
48use crate::SqlPlan;
49use crate::error::{
50    self, DataFrameSnafu, InferParameterTypesSnafu, InvalidPrepareStatementSnafu, Result,
51};
52use crate::metrics::METRIC_AUTH_FAILURE;
53use crate::mysql::helper::{
54    self, format_placeholder, replace_placeholders, transform_placeholders,
55};
56use crate::mysql::writer;
57use crate::mysql::writer::{create_mysql_column, handle_err};
58use crate::query_handler::sql::ServerSqlQueryHandlerRef;
59
/// Auth plugin whose credentials are verified against the connection salt
/// (`Password::MysqlNativePassword`).
const MYSQL_NATIVE_PASSWORD: &'static str = "mysql_native_password";
/// Auth plugin whose payload is the plain-text password; offered when the user
/// provider is external.
const MYSQL_CLEAR_PASSWORD: &'static str = "mysql_clear_password";
62
63/// Parameters for the prepared statement
64enum Params<'a> {
65    /// Parameters passed through protocol
66    ProtocolParams(Vec<ParamValue<'a>>),
67    /// Parameters passed through cli
68    CliParams(Vec<sql::ast::Expr>),
69}
70
71impl Params<'_> {
72    fn len(&self) -> usize {
73        match self {
74            Params::ProtocolParams(params) => params.len(),
75            Params::CliParams(params) => params.len(),
76        }
77    }
78}
79
80// An intermediate shim for executing MySQL queries.
81pub struct MysqlInstanceShim {
82    query_handler: ServerSqlQueryHandlerRef,
83    salt: [u8; 20],
84    session: SessionRef,
85    user_provider: Option<UserProviderRef>,
86    prepared_stmts: Arc<RwLock<HashMap<String, SqlPlan>>>,
87    prepared_stmts_counter: AtomicU32,
88    process_id: u32,
89    prepared_stmt_cache_size: usize,
90}
91
92impl MysqlInstanceShim {
93    pub fn create(
94        query_handler: ServerSqlQueryHandlerRef,
95        user_provider: Option<UserProviderRef>,
96        client_addr: SocketAddr,
97        process_id: u32,
98        prepared_stmt_cache_size: usize,
99    ) -> MysqlInstanceShim {
100        // init a random salt
101        let mut bs = vec![0u8; 20];
102        let mut rng = rand::rng();
103        rng.fill_bytes(bs.as_mut());
104
105        let mut scramble: [u8; 20] = [0; 20];
106        for i in 0..20 {
107            scramble[i] = bs[i] & 0x7fu8;
108            if scramble[i] == b'\0' || scramble[i] == b'$' {
109                scramble[i] += 1;
110            }
111        }
112
113        MysqlInstanceShim {
114            query_handler,
115            salt: scramble,
116            session: Arc::new(Session::new(
117                Some(client_addr),
118                Channel::Mysql,
119                Default::default(),
120                process_id,
121            )),
122            user_provider,
123            prepared_stmts: Default::default(),
124            prepared_stmts_counter: AtomicU32::new(1),
125            process_id,
126            prepared_stmt_cache_size,
127        }
128    }
129
130    #[tracing::instrument(skip_all, name = "mysql::do_query")]
131    async fn do_query(&self, query: &str, query_ctx: QueryContextRef) -> Vec<Result<Output>> {
132        if let Some(output) =
133            crate::mysql::federated::check(query, query_ctx.clone(), self.session.clone())
134        {
135            vec![Ok(output)]
136        } else {
137            self.query_handler.do_query(query, query_ctx.clone()).await
138        }
139    }
140
141    /// Execute the logical plan and return the output
142    async fn do_exec_plan(
143        &self,
144        query: &str,
145        stmt: Option<Statement>,
146        plan: LogicalPlan,
147        query_ctx: QueryContextRef,
148    ) -> Result<Output> {
149        if let Some(output) =
150            crate::mysql::federated::check(query, query_ctx.clone(), self.session.clone())
151        {
152            Ok(output)
153        } else {
154            self.query_handler.do_exec_plan(stmt, plan, query_ctx).await
155        }
156    }
157
158    /// Describe the statement
159    async fn do_describe(
160        &self,
161        statement: Statement,
162        query_ctx: QueryContextRef,
163    ) -> Result<Option<DescribeResult>> {
164        self.query_handler.do_describe(statement, query_ctx).await
165    }
166
167    /// Save query and logical plan with a given statement key
168    fn save_plan(&self, plan: SqlPlan, stmt_key: String) -> Result<()> {
169        let mut prepared_stmts = self.prepared_stmts.write();
170        let max_capacity = self.prepared_stmt_cache_size;
171
172        let is_update = prepared_stmts.contains_key(&stmt_key);
173
174        if !is_update && prepared_stmts.len() >= max_capacity {
175            return error::InternalSnafu {
176                err_msg: format!(
177                    "Prepared statement cache is full, max capacity: {}",
178                    max_capacity
179                ),
180            }
181            .fail();
182        }
183
184        let _ = prepared_stmts.insert(stmt_key, plan);
185        Ok(())
186    }
187
188    /// Retrieve the query and logical plan by a given statement key
189    fn plan(&self, stmt_key: &str) -> Option<SqlPlan> {
190        let guard = self.prepared_stmts.read();
191        guard.get(stmt_key).cloned()
192    }
193
194    /// Save the prepared statement and return the parameters and result columns
195    async fn do_prepare(
196        &mut self,
197        raw_query: &str,
198        query_ctx: QueryContextRef,
199        stmt_key: String,
200    ) -> Result<(Vec<Column>, Vec<Column>)> {
201        let (query, param_num) = replace_placeholders(raw_query);
202
203        let statement = validate_query(raw_query).await?;
204
205        // We have to transform the placeholder, because DataFusion only parses placeholders
206        // in the form of "$i", it can't process "?" right now.
207        let statement = transform_placeholders(statement);
208
209        let describe_result = self
210            .do_describe(statement.clone(), query_ctx.clone())
211            .await?;
212        let (plan, schema) = if let Some(DescribeResult {
213            logical_plan,
214            schema,
215        }) = describe_result
216        {
217            (Some(logical_plan), Some(schema))
218        } else {
219            (None, None)
220        };
221
222        let params = if let Some(plan) = &plan {
223            let param_types = DfLogicalPlanner::get_inferred_parameter_types(plan)
224                .context(InferParameterTypesSnafu)?
225                .into_iter()
226                .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
227                .collect();
228            prepared_params(&param_types)?
229        } else {
230            dummy_params(param_num)?
231        };
232
233        let columns = schema
234            .as_ref()
235            .map(|schema| {
236                schema
237                    .column_schemas()
238                    .iter()
239                    .map(|column_schema| {
240                        create_mysql_column(&column_schema.data_type, &column_schema.name)
241                    })
242                    .collect::<Result<Vec<_>>>()
243            })
244            .transpose()?
245            .unwrap_or_default();
246
247        // DataFusion may optimize the plan so that some parameters are not used.
248        if params.len() != param_num - 1 {
249            self.save_plan(
250                SqlPlan {
251                    query: query.clone(),
252                    statement: Some(statement),
253                    plan: None,
254                    schema: None,
255                },
256                stmt_key,
257            )
258            .map_err(|e| {
259                error!(e; "Failed to save prepared statement");
260                e
261            })?;
262        } else {
263            self.save_plan(
264                SqlPlan {
265                    query: query.clone(),
266                    statement: Some(statement),
267                    plan,
268                    schema,
269                },
270                stmt_key,
271            )
272            .map_err(|e| {
273                error!(e; "Failed to save prepared statement");
274                e
275            })?;
276        }
277
278        Ok((params, columns))
279    }
280
281    async fn do_execute(
282        &mut self,
283        query_ctx: QueryContextRef,
284        stmt_key: String,
285        params: Params<'_>,
286    ) -> Result<Vec<std::result::Result<Output, error::Error>>> {
287        let sql_plan = match self.plan(&stmt_key) {
288            None => {
289                return error::PrepareStatementNotFoundSnafu { name: stmt_key }.fail();
290            }
291            Some(sql_plan) => sql_plan,
292        };
293
294        let outputs = match sql_plan.plan {
295            Some(plan) => {
296                let param_types = DfLogicalPlanner::get_inferred_parameter_types(&plan)
297                    .context(InferParameterTypesSnafu)?
298                    .into_iter()
299                    .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
300                    .collect::<HashMap<_, _>>();
301
302                if params.len() != param_types.len() {
303                    return error::InternalSnafu {
304                        err_msg: "Prepare statement params number mismatch".to_string(),
305                    }
306                    .fail();
307                }
308
309                let plan = match params {
310                    Params::ProtocolParams(params) => {
311                        replace_params_with_values(&plan, param_types, &params)
312                    }
313                    Params::CliParams(params) => {
314                        replace_params_with_exprs(&plan, param_types, &params)
315                    }
316                }?;
317
318                debug!("Mysql execute prepared plan: {}", plan.display_indent());
319                vec![
320                    self.do_exec_plan(
321                        &sql_plan.query,
322                        sql_plan.statement.clone(),
323                        plan,
324                        query_ctx.clone(),
325                    )
326                    .await,
327                ]
328            }
329            None => {
330                let param_strs = match params {
331                    Params::ProtocolParams(params) => {
332                        params.iter().map(convert_param_value_to_string).collect()
333                    }
334                    Params::CliParams(params) => params.iter().map(|x| x.to_string()).collect(),
335                };
336                debug!(
337                    "do_execute Replacing with Params: {:?}, Original Query: {}",
338                    param_strs, sql_plan.query
339                );
340                let query = replace_params(param_strs, sql_plan.query);
341                debug!("Mysql execute replaced query: {}", query);
342                self.do_query(&query, query_ctx.clone()).await
343            }
344        };
345
346        Ok(outputs)
347    }
348
349    /// Remove the prepared statement by a given statement key
350    fn do_close(&mut self, stmt_key: String) {
351        let mut guard = self.prepared_stmts.write();
352        let _ = guard.remove(&stmt_key);
353    }
354
355    fn auth_plugin(&self) -> &'static str {
356        if self
357            .user_provider
358            .as_ref()
359            .map(|x| x.external())
360            .unwrap_or(false)
361        {
362            MYSQL_CLEAR_PASSWORD
363        } else {
364            MYSQL_NATIVE_PASSWORD
365        }
366    }
367}
368
369#[async_trait]
370impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShim {
371    type Error = error::Error;
372
373    fn version(&self) -> String {
374        std::env::var("GREPTIMEDB_MYSQL_SERVER_VERSION").unwrap_or_else(|_| "8.4.2".to_string())
375    }
376
377    fn connect_id(&self) -> u32 {
378        self.process_id
379    }
380
381    fn default_auth_plugin(&self) -> &str {
382        self.auth_plugin()
383    }
384
385    async fn auth_plugin_for_username(&self, _user: &[u8]) -> &'static str {
386        self.auth_plugin()
387    }
388
389    fn salt(&self) -> [u8; 20] {
390        self.salt
391    }
392
393    async fn authenticate(
394        &self,
395        auth_plugin: &str,
396        username: &[u8],
397        salt: &[u8],
398        auth_data: &[u8],
399    ) -> bool {
400        // if not specified then **greptime** will be used
401        let username = String::from_utf8_lossy(username);
402
403        let mut user_info = None;
404        let addr = self
405            .session
406            .conn_info()
407            .client_addr
408            .map(|addr| addr.to_string());
409        if let Some(user_provider) = &self.user_provider {
410            let user_id = Identity::UserId(&username, addr.as_deref());
411
412            let password = match auth_plugin {
413                MYSQL_NATIVE_PASSWORD => Password::MysqlNativePassword(auth_data, salt),
414                MYSQL_CLEAR_PASSWORD => {
415                    // The raw bytes received could be represented in C-like string, ended in '\0'.
416                    // We must "trim" it to get the real password string.
417                    let password = if let &[password @ .., 0] = &auth_data {
418                        password
419                    } else {
420                        auth_data
421                    };
422                    Password::PlainText(String::from_utf8_lossy(password).to_string().into())
423                }
424                other => {
425                    error!("Unsupported mysql auth plugin: {}", other);
426                    return false;
427                }
428            };
429            match user_provider.authenticate(user_id, password).await {
430                Ok(userinfo) => {
431                    user_info = Some(userinfo);
432                }
433                Err(e) => {
434                    METRIC_AUTH_FAILURE
435                        .with_label_values(&[e.status_code().as_ref()])
436                        .inc();
437                    warn!(e; "Failed to auth");
438                    return false;
439                }
440            };
441        }
442        let user_info =
443            user_info.unwrap_or_else(|| auth::userinfo_by_name(Some(username.to_string())));
444
445        self.session.set_user_info(user_info);
446
447        true
448    }
449
450    async fn on_prepare<'a>(
451        &'a mut self,
452        raw_query: &'a str,
453        w: StatementMetaWriter<'a, W>,
454    ) -> Result<()> {
455        let query_ctx = self.session.new_query_context();
456        let stmt_id = self.prepared_stmts_counter.fetch_add(1, Ordering::Relaxed);
457        let stmt_key = uuid::Uuid::from_u128(stmt_id as u128).to_string();
458        let (params, columns) = match self
459            .do_prepare(raw_query, query_ctx.clone(), stmt_key)
460            .await
461        {
462            Ok(x) => x,
463            Err(e) => {
464                let (kind, msg) = handle_err(e, query_ctx.clone());
465                w.error(kind, msg.as_bytes()).await?;
466                return Ok(());
467            }
468        };
469        debug!("on_prepare: Params: {:?}, Columns: {:?}", params, columns);
470        w.reply(stmt_id, &params, &columns).await?;
471        crate::metrics::METRIC_MYSQL_PREPARED_COUNT
472            .with_label_values(&[query_ctx.get_db_string().as_str()])
473            .inc();
474        return Ok(());
475    }
476
477    async fn on_execute<'a>(
478        &'a mut self,
479        stmt_id: u32,
480        p: ParamParser<'a>,
481        w: QueryResultWriter<'a, W>,
482    ) -> Result<()> {
483        self.session.clear_warnings();
484
485        let query_ctx = self.session.new_query_context();
486        let db = query_ctx.get_db_string();
487        let _timer = crate::metrics::METRIC_MYSQL_QUERY_TIMER
488            .with_label_values(&[crate::metrics::METRIC_MYSQL_BINQUERY, db.as_str()])
489            .start_timer();
490
491        let params: Vec<ParamValue> = p.into_iter().collect();
492        let stmt_key = uuid::Uuid::from_u128(stmt_id as u128).to_string();
493
494        let outputs = match self
495            .do_execute(query_ctx.clone(), stmt_key, Params::ProtocolParams(params))
496            .await
497        {
498            Ok(outputs) => outputs,
499            Err(e) => {
500                let (kind, err) = handle_err(e, query_ctx);
501                debug!(
502                    "Failed to execute prepared statement, kind: {:?}, err: {}",
503                    kind, err
504                );
505                w.error(kind, err.as_bytes()).await?;
506                return Ok(());
507            }
508        };
509
510        writer::write_output(w, query_ctx, self.session.clone(), outputs).await?;
511
512        Ok(())
513    }
514
515    async fn on_close<'a>(&'a mut self, stmt_id: u32)
516    where
517        W: 'async_trait,
518    {
519        let stmt_key = uuid::Uuid::from_u128(stmt_id as u128).to_string();
520        self.do_close(stmt_key);
521    }
522
523    #[tracing::instrument(skip_all, fields(protocol = "mysql"))]
524    async fn on_query<'a>(
525        &'a mut self,
526        query: &'a str,
527        writer: QueryResultWriter<'a, W>,
528    ) -> Result<()> {
529        let query_ctx = self.session.new_query_context();
530        let db = query_ctx.get_db_string();
531        let _timer = crate::metrics::METRIC_MYSQL_QUERY_TIMER
532            .with_label_values(&[crate::metrics::METRIC_MYSQL_TEXTQUERY, db.as_str()])
533            .start_timer();
534
535        // Clear warnings for non SHOW WARNINGS queries
536        let query_upcase = query.to_uppercase();
537        if !query_upcase.starts_with("SHOW WARNINGS") {
538            self.session.clear_warnings();
539        }
540
541        if query_upcase.starts_with("PREPARE ") {
542            match ParserContext::parse_mysql_prepare_stmt(query, query_ctx.sql_dialect()) {
543                Ok((stmt_name, stmt)) => {
544                    let prepare_results =
545                        self.do_prepare(&stmt, query_ctx.clone(), stmt_name).await;
546                    match prepare_results {
547                        Ok(_) => {
548                            let outputs = vec![Ok(Output::new_with_affected_rows(0))];
549                            writer::write_output(writer, query_ctx, self.session.clone(), outputs)
550                                .await?;
551                            return Ok(());
552                        }
553                        Err(e) => {
554                            writer
555                                .error(ErrorKind::ER_SP_BADSTATEMENT, e.output_msg().as_bytes())
556                                .await?;
557                            return Ok(());
558                        }
559                    }
560                }
561                Err(e) => {
562                    writer
563                        .error(ErrorKind::ER_PARSE_ERROR, e.output_msg().as_bytes())
564                        .await?;
565                    return Ok(());
566                }
567            }
568        } else if query_upcase.starts_with("EXECUTE ") {
569            match ParserContext::parse_mysql_execute_stmt(query, query_ctx.sql_dialect()) {
570                Ok((stmt_name, params)) => {
571                    let outputs = match self
572                        .do_execute(query_ctx.clone(), stmt_name, Params::CliParams(params))
573                        .await
574                    {
575                        Ok(outputs) => outputs,
576                        Err(e) => {
577                            let (kind, err) = handle_err(e, query_ctx);
578                            debug!(
579                                "Failed to execute prepared statement, kind: {:?}, err: {}",
580                                kind, err
581                            );
582                            writer.error(kind, err.as_bytes()).await?;
583                            return Ok(());
584                        }
585                    };
586                    writer::write_output(writer, query_ctx, self.session.clone(), outputs).await?;
587
588                    return Ok(());
589                }
590                Err(e) => {
591                    writer
592                        .error(ErrorKind::ER_PARSE_ERROR, e.output_msg().as_bytes())
593                        .await?;
594                    return Ok(());
595                }
596            }
597        } else if query_upcase.starts_with("DEALLOCATE ") {
598            match ParserContext::parse_mysql_deallocate_stmt(query, query_ctx.sql_dialect()) {
599                Ok(stmt_name) => {
600                    self.do_close(stmt_name);
601                    let outputs = vec![Ok(Output::new_with_affected_rows(0))];
602                    writer::write_output(writer, query_ctx, self.session.clone(), outputs).await?;
603                    return Ok(());
604                }
605                Err(e) => {
606                    writer
607                        .error(ErrorKind::ER_PARSE_ERROR, e.output_msg().as_bytes())
608                        .await?;
609                    return Ok(());
610                }
611            }
612        }
613
614        let outputs = self.do_query(query, query_ctx.clone()).await;
615        writer::write_output(writer, query_ctx, self.session.clone(), outputs).await?;
616
617        Ok(())
618    }
619
620    async fn on_init<'a>(&'a mut self, database: &'a str, w: InitWriter<'a, W>) -> Result<()> {
621        let (catalog_from_db, schema) = parse_optional_catalog_and_schema_from_db_string(database);
622        let catalog = if let Some(catalog) = &catalog_from_db {
623            catalog.clone()
624        } else {
625            self.session.catalog()
626        };
627
628        if !self
629            .query_handler
630            .is_valid_schema(&catalog, &schema)
631            .await?
632        {
633            return w
634                .error(
635                    ErrorKind::ER_WRONG_DB_NAME,
636                    format!("Unknown database '{}'", database).as_bytes(),
637                )
638                .await
639                .map_err(|e| e.into());
640        }
641
642        let user_info = &self.session.user_info();
643
644        if let Some(schema_validator) = &self.user_provider
645            && let Err(e) = schema_validator
646                .authorize(&catalog, &schema, user_info)
647                .await
648        {
649            METRIC_AUTH_FAILURE
650                .with_label_values(&[e.status_code().as_ref()])
651                .inc();
652            return w
653                .error(
654                    ErrorKind::ER_DBACCESS_DENIED_ERROR,
655                    e.output_msg().as_bytes(),
656                )
657                .await
658                .map_err(|e| e.into());
659        }
660
661        if catalog_from_db.is_some() {
662            self.session.set_catalog(catalog)
663        }
664        self.session.set_schema(schema);
665
666        w.ok().await.map_err(|e| e.into())
667    }
668}
669
670fn convert_param_value_to_string(param: &ParamValue) -> String {
671    match param.value.into_inner() {
672        ValueInner::Int(u) => u.to_string(),
673        ValueInner::UInt(u) => u.to_string(),
674        ValueInner::Double(u) => u.to_string(),
675        ValueInner::NULL => "NULL".to_string(),
676        ValueInner::Bytes(b) => format!("'{}'", &String::from_utf8_lossy(b)),
677        ValueInner::Date(_) => format!("'{}'", NaiveDate::from(param.value)),
678        ValueInner::Datetime(_) => format!("'{}'", NaiveDateTime::from(param.value)),
679        ValueInner::Time(_) => format_duration(Duration::from(param.value)),
680    }
681}
682
683fn replace_params(params: Vec<String>, query: String) -> String {
684    let mut query = query;
685    for (index, param) in (1..).zip(params) {
686        query = query.replace(&format_placeholder(index), &param);
687    }
688    query
689}
690
/// Formats `duration` as a quoted `'HH:MM:SS[.ffffff]'` time literal.
///
/// Hours, minutes and seconds are zero-padded to at least two digits (hours may
/// grow beyond two). A non-zero sub-second part is kept as six-digit microseconds
/// instead of being silently dropped; a zero fraction is omitted.
fn format_duration(duration: Duration) -> String {
    let total_secs = duration.as_secs();
    let hours = total_secs / 3600;
    let minutes = (total_secs / 60) % 60;
    let seconds = total_secs % 60;
    let micros = duration.subsec_micros();
    if micros == 0 {
        format!("'{:02}:{:02}:{:02}'", hours, minutes, seconds)
    } else {
        format!("'{:02}:{:02}:{:02}.{:06}'", hours, minutes, seconds, micros)
    }
}
697
698fn replace_params_with_values(
699    plan: &LogicalPlan,
700    param_types: HashMap<String, Option<ConcreteDataType>>,
701    params: &[ParamValue],
702) -> Result<LogicalPlan> {
703    debug_assert_eq!(param_types.len(), params.len());
704
705    debug!(
706        "replace_params_with_values(param_types: {:#?}, params: {:#?}, plan: {:#?})",
707        param_types,
708        params
709            .iter()
710            .map(|x| format!("({:?}, {:?})", x.value, x.coltype))
711            .join(", "),
712        plan
713    );
714
715    let mut values = Vec::with_capacity(params.len());
716
717    for (i, param) in params.iter().enumerate() {
718        if let Some(Some(t)) = param_types.get(&format_placeholder(i + 1)) {
719            let value = helper::convert_value(param, t)?;
720
721            values.push(value.into());
722        }
723    }
724
725    plan.clone()
726        .replace_params_with_values(&ParamValues::List(values.clone()))
727        .context(DataFrameSnafu)
728}
729
730fn replace_params_with_exprs(
731    plan: &LogicalPlan,
732    param_types: HashMap<String, Option<ConcreteDataType>>,
733    params: &[sql::ast::Expr],
734) -> Result<LogicalPlan> {
735    debug_assert_eq!(param_types.len(), params.len());
736
737    debug!(
738        "replace_params_with_exprs(param_types: {:#?}, params: {:#?}, plan: {:#?})",
739        param_types,
740        params.iter().map(|x| format!("({:?})", x)).join(", "),
741        plan
742    );
743
744    let mut values = Vec::with_capacity(params.len());
745
746    for (i, param) in params.iter().enumerate() {
747        if let Some(Some(t)) = param_types.get(&format_placeholder(i + 1)) {
748            let value = helper::convert_expr_to_scalar_value(param, t)?;
749
750            values.push(value.into());
751        }
752    }
753
754    plan.clone()
755        .replace_params_with_values(&ParamValues::List(values.clone()))
756        .context(DataFrameSnafu)
757}
758
759async fn validate_query(query: &str) -> Result<Statement> {
760    let statement =
761        ParserContext::create_with_dialect(query, &MySqlDialect {}, ParseOptions::default());
762    let mut statement = statement.map_err(|e| {
763        InvalidPrepareStatementSnafu {
764            err_msg: e.output_msg(),
765        }
766        .build()
767    })?;
768
769    ensure!(
770        statement.len() == 1,
771        InvalidPrepareStatementSnafu {
772            err_msg: "prepare statement only support single statement".to_string(),
773        }
774    );
775
776    let statement = statement.remove(0);
777
778    Ok(statement)
779}
780
781fn dummy_params(index: usize) -> Result<Vec<Column>> {
782    let mut params = Vec::with_capacity(index - 1);
783
784    for _ in 1..index {
785        params.push(create_mysql_column(&ConcreteDataType::null_datatype(), "")?);
786    }
787
788    Ok(params)
789}
790
791/// Parameters that the client must provide when executing the prepared statement.
792fn prepared_params(param_types: &HashMap<String, Option<ConcreteDataType>>) -> Result<Vec<Column>> {
793    let mut params = Vec::with_capacity(param_types.len());
794
795    // Placeholder index starts from 1
796    for index in 1..=param_types.len() {
797        if let Some(Some(t)) = param_types.get(&format_placeholder(index)) {
798            let column = create_mysql_column(t, "")?;
799            params.push(column);
800        }
801    }
802
803    Ok(params)
804}