Skip to main content

servers/mysql/
handler.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::net::SocketAddr;
17use std::sync::Arc;
18use std::sync::atomic::{AtomicU32, Ordering};
19use std::time::Duration;
20
21use ::auth::{Identity, Password, UserProviderRef};
22use async_trait::async_trait;
23use chrono::{NaiveDate, NaiveDateTime};
24use common_catalog::parse_optional_catalog_and_schema_from_db_string;
25use common_error::ext::ErrorExt;
26use common_query::Output;
27use common_telemetry::{debug, error, tracing, warn};
28use datafusion_common::ParamValues;
29use datafusion_expr::LogicalPlan;
30use datatypes::prelude::ConcreteDataType;
31use datatypes::schema::Schema;
32use itertools::Itertools;
33use mysql_common::Value as MysqlValue;
34use opensrv_mysql::{
35    AsyncMysqlShim, Column, ErrorKind, InitWriter, ParamParser, ParamValue, QueryResultWriter,
36    StatementMetaWriter, ValueInner,
37};
38use parking_lot::RwLock;
39use query::planner::DfLogicalPlanner;
40use query::query_engine::DescribeResult;
41use rand::RngCore;
42use session::context::{Channel, QueryContextRef};
43use session::{Session, SessionRef};
44use snafu::{ResultExt, ensure};
45use sql::dialect::MySqlDialect;
46use sql::parser::{ParseOptions, ParserContext};
47use sql::statements::statement::Statement;
48use tokio::io::AsyncWrite;
49
50use crate::SqlPlan;
51use crate::error::{
52    self, DataFrameSnafu, InferParameterTypesSnafu, InvalidPrepareStatementSnafu, Result,
53};
54use crate::metrics::METRIC_AUTH_FAILURE;
55use crate::mysql::helper::{self, format_placeholder, transform_placeholders_with_count};
56use crate::mysql::writer;
57use crate::mysql::writer::{create_mysql_column, handle_err};
58use crate::query_handler::sql::ServerSqlQueryHandlerRef;
59
/// Auth plugin whose response is checked against the connection salt
/// (see `Password::MysqlNativePassword` in `authenticate`).
const MYSQL_NATIVE_PASSWORD: &str = "mysql_native_password";
/// Auth plugin whose response carries the password itself, possibly
/// NUL-terminated (see `Password::PlainText` in `authenticate`).
const MYSQL_CLEAR_PASSWORD: &str = "mysql_clear_password";
62
63/// Parameters for the prepared statement
64enum Params<'a> {
65    /// Parameters passed through protocol
66    ProtocolParams(Vec<ParamValue<'a>>),
67    /// Parameters passed through cli
68    CliParams(Vec<sql::ast::Expr>),
69}
70
71impl Params<'_> {
72    fn len(&self) -> usize {
73        match self {
74            Params::ProtocolParams(params) => params.len(),
75            Params::CliParams(params) => params.len(),
76        }
77    }
78}
79
80// An intermediate shim for executing MySQL queries.
81pub struct MysqlInstanceShim {
82    query_handler: ServerSqlQueryHandlerRef,
83    salt: [u8; 20],
84    session: SessionRef,
85    user_provider: Option<UserProviderRef>,
86    prepared_stmts: Arc<RwLock<HashMap<String, SqlPlan>>>,
87    prepared_stmts_counter: AtomicU32,
88    process_id: u32,
89    prepared_stmt_cache_size: usize,
90}
91
92impl MysqlInstanceShim {
93    pub fn create(
94        query_handler: ServerSqlQueryHandlerRef,
95        user_provider: Option<UserProviderRef>,
96        client_addr: SocketAddr,
97        process_id: u32,
98        prepared_stmt_cache_size: usize,
99    ) -> MysqlInstanceShim {
100        // init a random salt
101        let mut bs = vec![0u8; 20];
102        let mut rng = rand::rng();
103        rng.fill_bytes(bs.as_mut());
104
105        let mut scramble: [u8; 20] = [0; 20];
106        for i in 0..20 {
107            scramble[i] = bs[i] & 0x7fu8;
108            if scramble[i] == b'\0' || scramble[i] == b'$' {
109                scramble[i] += 1;
110            }
111        }
112
113        MysqlInstanceShim {
114            query_handler,
115            salt: scramble,
116            session: Arc::new(Session::new(
117                Some(client_addr),
118                Channel::Mysql,
119                Default::default(),
120                process_id,
121            )),
122            user_provider,
123            prepared_stmts: Default::default(),
124            prepared_stmts_counter: AtomicU32::new(1),
125            process_id,
126            prepared_stmt_cache_size,
127        }
128    }
129
130    #[tracing::instrument(skip_all, name = "mysql::do_query")]
131    async fn do_query(&self, query: &str, query_ctx: QueryContextRef) -> Vec<Result<Output>> {
132        if let Some(output) =
133            crate::mysql::federated::check(query, query_ctx.clone(), self.session.clone())
134        {
135            vec![Ok(output)]
136        } else {
137            self.query_handler.do_query(query, query_ctx.clone()).await
138        }
139    }
140
141    /// Describe the statement
142    async fn do_describe(
143        &self,
144        statement: Statement,
145        query_ctx: QueryContextRef,
146    ) -> Result<Option<DescribeResult>> {
147        self.query_handler.do_describe(statement, query_ctx).await
148    }
149
150    /// Save query and logical plan with a given statement key
151    fn save_plan(&self, plan: SqlPlan, stmt_key: String) -> Result<()> {
152        let mut prepared_stmts = self.prepared_stmts.write();
153        let max_capacity = self.prepared_stmt_cache_size;
154
155        let is_update = prepared_stmts.contains_key(&stmt_key);
156
157        if !is_update && prepared_stmts.len() >= max_capacity {
158            return error::InternalSnafu {
159                err_msg: format!(
160                    "Prepared statement cache is full, max capacity: {}",
161                    max_capacity
162                ),
163            }
164            .fail();
165        }
166
167        let _ = prepared_stmts.insert(stmt_key, plan);
168        Ok(())
169    }
170
171    /// Retrieve the query and logical plan by a given statement key
172    fn plan(&self, stmt_key: &str) -> Option<SqlPlan> {
173        let guard = self.prepared_stmts.read();
174        guard.get(stmt_key).cloned()
175    }
176
177    /// Save the prepared statement and return the parameters and result columns
178    async fn do_prepare(
179        &mut self,
180        raw_query: &str,
181        query_ctx: QueryContextRef,
182        stmt_key: String,
183    ) -> Result<(Vec<Column>, Vec<Column>)> {
184        if crate::mysql::federated::check(raw_query, query_ctx.clone(), self.session.clone())
185            .is_some()
186        {
187            self.save_plan(SqlPlan::Shortcut(raw_query.to_string()), stmt_key)
188                .inspect_err(|e| {
189                    error!(e; "Failed to save prepared statement");
190                })?;
191            return Ok((vec![], vec![]));
192        }
193
194        let statement = validate_query(raw_query).await?;
195
196        // We have to transform the placeholder, because DataFusion only parses placeholders
197        // in the form of "$i", it can't process "?" right now.
198        let (statement, placeholder_count) = transform_placeholders_with_count(statement);
199        let param_num = placeholder_count + 1;
200
201        let describe_result = self
202            .do_describe(statement.clone(), query_ctx.clone())
203            .await?;
204        let plan = describe_result.map(|DescribeResult { logical_plan }| logical_plan);
205
206        let (params, can_cache_as_plan) = if let Some(plan) = &plan {
207            let param_types = DfLogicalPlanner::get_inferred_parameter_types(plan)
208                .context(InferParameterTypesSnafu)?
209                .into_iter()
210                .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
211                .collect();
212
213            (
214                prepared_params(&param_types, param_num)?,
215                all_params_have_types(&param_types, param_num),
216            )
217        } else {
218            (dummy_params(param_num)?, false)
219        };
220
221        let columns =
222            plan.as_ref()
223                .map(|plan| {
224                    let schema: Schema = plan.schema().clone().try_into().map_err(
225                        |e: datatypes::error::Error| {
226                            error::InternalSnafu {
227                                err_msg: e.to_string(),
228                            }
229                            .build()
230                        },
231                    )?;
232                    schema
233                        .column_schemas()
234                        .iter()
235                        .map(|column_schema| {
236                            create_mysql_column(&column_schema.data_type, &column_schema.name)
237                        })
238                        .collect::<Result<Vec<_>>>()
239                })
240                .transpose()?
241                .unwrap_or_default();
242
243        match plan {
244            Some(plan) if can_cache_as_plan => {
245                self.save_plan(SqlPlan::Plan(plan, statement), stmt_key)
246                    .inspect_err(|e| {
247                        error!(e; "Failed to save prepared statement");
248                    })?;
249            }
250            _ => {
251                self.save_plan(
252                    SqlPlan::Statement(statement, raw_query.to_string()),
253                    stmt_key,
254                )
255                .inspect_err(|e| {
256                    error!(e; "Failed to save prepared statement");
257                })?;
258            }
259        }
260
261        Ok((params, columns))
262    }
263
264    async fn do_execute(
265        &mut self,
266        query_ctx: QueryContextRef,
267        stmt_key: String,
268        params: Params<'_>,
269    ) -> Result<Vec<std::result::Result<Output, error::Error>>> {
270        let sql_plan = match self.plan(&stmt_key) {
271            None => {
272                return error::PrepareStatementNotFoundSnafu { name: stmt_key }.fail();
273            }
274            Some(sql_plan) => sql_plan,
275        };
276
277        let outputs = match sql_plan {
278            SqlPlan::Plan(plan, stmt) => {
279                let param_types = DfLogicalPlanner::get_inferred_parameter_types(&plan)
280                    .context(InferParameterTypesSnafu)?
281                    .into_iter()
282                    .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
283                    .collect::<HashMap<_, _>>();
284
285                if params.len() != param_types.len() {
286                    return error::InternalSnafu {
287                        err_msg: "Prepare statement params number mismatch".to_string(),
288                    }
289                    .fail();
290                }
291
292                let replaced_plan = match params {
293                    Params::ProtocolParams(params) => {
294                        replace_params_with_values(&plan, param_types, &params)
295                    }
296                    Params::CliParams(params) => {
297                        replace_params_with_exprs(&plan, param_types, &params)
298                    }
299                }?;
300
301                debug!(
302                    "Mysql execute prepared plan: {}",
303                    replaced_plan.display_indent()
304                );
305                vec![
306                    self.query_handler
307                        .do_exec_plan(replaced_plan, Some(stmt), query_ctx.clone())
308                        .await,
309                ]
310            }
311            SqlPlan::Shortcut(query) => {
312                if let Some(output) =
313                    crate::mysql::federated::check(&query, query_ctx.clone(), self.session.clone())
314                {
315                    vec![Ok(output)]
316                } else {
317                    self.do_query(&query, query_ctx.clone()).await
318                }
319            }
320            SqlPlan::Statement(stmt, query) => {
321                let param_strs = match params {
322                    Params::ProtocolParams(params) => {
323                        params.iter().map(convert_param_value_to_string).collect()
324                    }
325                    Params::CliParams(params) => params.iter().map(|x| x.to_string()).collect(),
326                };
327                debug!(
328                    "do_execute Replacing with Params: {:?}, Original Query: {}",
329                    param_strs, query
330                );
331                let query = replace_params(param_strs, stmt, query)?;
332                debug!("Mysql execute replaced query: {}", query);
333                self.do_query(&query, query_ctx.clone()).await
334            }
335            _ => {
336                return error::PrepareStatementNotFoundSnafu { name: stmt_key }.fail();
337            }
338        };
339
340        Ok(outputs)
341    }
342
343    /// Remove the prepared statement by a given statement key
344    fn do_close(&mut self, stmt_key: String) {
345        let mut guard = self.prepared_stmts.write();
346        let _ = guard.remove(&stmt_key);
347    }
348
349    fn auth_plugin(&self) -> &'static str {
350        if self
351            .user_provider
352            .as_ref()
353            .map(|x| x.external())
354            .unwrap_or(false)
355        {
356            MYSQL_CLEAR_PASSWORD
357        } else {
358            MYSQL_NATIVE_PASSWORD
359        }
360    }
361}
362
363#[async_trait]
364impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShim {
365    type Error = error::Error;
366
367    fn version(&self) -> String {
368        std::env::var("GREPTIMEDB_MYSQL_SERVER_VERSION").unwrap_or_else(|_| "8.4.2".to_string())
369    }
370
371    fn connect_id(&self) -> u32 {
372        self.process_id
373    }
374
375    fn default_auth_plugin(&self) -> &str {
376        self.auth_plugin()
377    }
378
379    async fn auth_plugin_for_username(&self, _user: &[u8]) -> &'static str {
380        self.auth_plugin()
381    }
382
383    fn salt(&self) -> [u8; 20] {
384        self.salt
385    }
386
387    async fn authenticate(
388        &self,
389        auth_plugin: &str,
390        username: &[u8],
391        salt: &[u8],
392        auth_data: &[u8],
393    ) -> bool {
394        // if not specified then **greptime** will be used
395        let username = String::from_utf8_lossy(username);
396
397        let mut user_info = None;
398        let addr = self
399            .session
400            .conn_info()
401            .client_addr
402            .map(|addr| addr.to_string());
403        if let Some(user_provider) = &self.user_provider {
404            let user_id = Identity::UserId(&username, addr.as_deref());
405
406            let password = match auth_plugin {
407                MYSQL_NATIVE_PASSWORD => Password::MysqlNativePassword(auth_data, salt),
408                MYSQL_CLEAR_PASSWORD => {
409                    // The raw bytes received could be represented in C-like string, ended in '\0'.
410                    // We must "trim" it to get the real password string.
411                    let password = if let &[password @ .., 0] = &auth_data {
412                        password
413                    } else {
414                        auth_data
415                    };
416                    Password::PlainText(String::from_utf8_lossy(password).to_string().into())
417                }
418                other => {
419                    error!("Unsupported mysql auth plugin: {}", other);
420                    return false;
421                }
422            };
423            match user_provider.authenticate(user_id, password).await {
424                Ok(userinfo) => {
425                    user_info = Some(userinfo);
426                }
427                Err(e) => {
428                    METRIC_AUTH_FAILURE
429                        .with_label_values(&[e.status_code().as_ref()])
430                        .inc();
431                    warn!(e; "Failed to auth");
432                    return false;
433                }
434            };
435        }
436        let user_info =
437            user_info.unwrap_or_else(|| auth::userinfo_by_name(Some(username.to_string())));
438
439        self.session.set_user_info(user_info);
440
441        true
442    }
443
444    async fn on_prepare<'a>(
445        &'a mut self,
446        raw_query: &'a str,
447        w: StatementMetaWriter<'a, W>,
448    ) -> Result<()> {
449        let query_ctx = self.session.new_query_context();
450        let stmt_id = self.prepared_stmts_counter.fetch_add(1, Ordering::Relaxed);
451        let stmt_key = uuid::Uuid::from_u128(stmt_id as u128).to_string();
452        let (params, columns) = match self
453            .do_prepare(raw_query, query_ctx.clone(), stmt_key)
454            .await
455        {
456            Ok(x) => x,
457            Err(e) => {
458                let (kind, msg) = handle_err(e, query_ctx.clone());
459                w.error(kind, msg.as_bytes()).await?;
460                return Ok(());
461            }
462        };
463        debug!("on_prepare: Params: {:?}, Columns: {:?}", params, columns);
464        w.reply(stmt_id, &params, &columns).await?;
465        crate::metrics::METRIC_MYSQL_PREPARED_COUNT
466            .with_label_values(&[query_ctx.get_db_string().as_str()])
467            .inc();
468        return Ok(());
469    }
470
471    async fn on_execute<'a>(
472        &'a mut self,
473        stmt_id: u32,
474        p: ParamParser<'a>,
475        w: QueryResultWriter<'a, W>,
476    ) -> Result<()> {
477        self.session.clear_warnings();
478
479        let query_ctx = self.session.new_query_context();
480        let db = query_ctx.get_db_string();
481        let _timer = crate::metrics::METRIC_MYSQL_QUERY_TIMER
482            .with_label_values(&[crate::metrics::METRIC_MYSQL_BINQUERY, db.as_str()])
483            .start_timer();
484
485        let params: Vec<ParamValue> = p.into_iter().collect();
486        let stmt_key = uuid::Uuid::from_u128(stmt_id as u128).to_string();
487
488        let outputs = match self
489            .do_execute(query_ctx.clone(), stmt_key, Params::ProtocolParams(params))
490            .await
491        {
492            Ok(outputs) => outputs,
493            Err(e) => {
494                let (kind, err) = handle_err(e, query_ctx);
495                debug!(
496                    "Failed to execute prepared statement, kind: {:?}, err: {}",
497                    kind, err
498                );
499                w.error(kind, err.as_bytes()).await?;
500                return Ok(());
501            }
502        };
503
504        writer::write_output(w, query_ctx, self.session.clone(), outputs).await?;
505
506        Ok(())
507    }
508
509    async fn on_close<'a>(&'a mut self, stmt_id: u32)
510    where
511        W: 'async_trait,
512    {
513        let stmt_key = uuid::Uuid::from_u128(stmt_id as u128).to_string();
514        self.do_close(stmt_key);
515    }
516
517    #[tracing::instrument(skip_all, fields(protocol = "mysql"))]
518    async fn on_query<'a>(
519        &'a mut self,
520        query: &'a str,
521        writer: QueryResultWriter<'a, W>,
522    ) -> Result<()> {
523        let query_ctx = self.session.new_query_context();
524        let db = query_ctx.get_db_string();
525        let _timer = crate::metrics::METRIC_MYSQL_QUERY_TIMER
526            .with_label_values(&[crate::metrics::METRIC_MYSQL_TEXTQUERY, db.as_str()])
527            .start_timer();
528
529        // Clear warnings for non SHOW WARNINGS queries
530        let query_upcase = query.to_uppercase();
531        if !query_upcase.starts_with("SHOW WARNINGS") {
532            self.session.clear_warnings();
533        }
534
535        if query_upcase.starts_with("PREPARE ") {
536            match ParserContext::parse_mysql_prepare_stmt(query, query_ctx.sql_dialect()) {
537                Ok((stmt_name, stmt)) => {
538                    let prepare_results =
539                        self.do_prepare(&stmt, query_ctx.clone(), stmt_name).await;
540                    match prepare_results {
541                        Ok(_) => {
542                            let outputs = vec![Ok(Output::new_with_affected_rows(0))];
543                            writer::write_output(writer, query_ctx, self.session.clone(), outputs)
544                                .await?;
545                            return Ok(());
546                        }
547                        Err(e) => {
548                            writer
549                                .error(ErrorKind::ER_SP_BADSTATEMENT, e.output_msg().as_bytes())
550                                .await?;
551                            return Ok(());
552                        }
553                    }
554                }
555                Err(e) => {
556                    writer
557                        .error(ErrorKind::ER_PARSE_ERROR, e.output_msg().as_bytes())
558                        .await?;
559                    return Ok(());
560                }
561            }
562        } else if query_upcase.starts_with("EXECUTE ") {
563            match ParserContext::parse_mysql_execute_stmt(query, query_ctx.sql_dialect()) {
564                Ok((stmt_name, params)) => {
565                    let outputs = match self
566                        .do_execute(query_ctx.clone(), stmt_name, Params::CliParams(params))
567                        .await
568                    {
569                        Ok(outputs) => outputs,
570                        Err(e) => {
571                            let (kind, err) = handle_err(e, query_ctx);
572                            debug!(
573                                "Failed to execute prepared statement, kind: {:?}, err: {}",
574                                kind, err
575                            );
576                            writer.error(kind, err.as_bytes()).await?;
577                            return Ok(());
578                        }
579                    };
580                    writer::write_output(writer, query_ctx, self.session.clone(), outputs).await?;
581
582                    return Ok(());
583                }
584                Err(e) => {
585                    writer
586                        .error(ErrorKind::ER_PARSE_ERROR, e.output_msg().as_bytes())
587                        .await?;
588                    return Ok(());
589                }
590            }
591        } else if query_upcase.starts_with("DEALLOCATE ") {
592            match ParserContext::parse_mysql_deallocate_stmt(query, query_ctx.sql_dialect()) {
593                Ok(stmt_name) => {
594                    self.do_close(stmt_name);
595                    let outputs = vec![Ok(Output::new_with_affected_rows(0))];
596                    writer::write_output(writer, query_ctx, self.session.clone(), outputs).await?;
597                    return Ok(());
598                }
599                Err(e) => {
600                    writer
601                        .error(ErrorKind::ER_PARSE_ERROR, e.output_msg().as_bytes())
602                        .await?;
603                    return Ok(());
604                }
605            }
606        }
607
608        let outputs = self.do_query(query, query_ctx.clone()).await;
609        writer::write_output(writer, query_ctx, self.session.clone(), outputs).await?;
610
611        Ok(())
612    }
613
614    async fn on_init<'a>(&'a mut self, database: &'a str, w: InitWriter<'a, W>) -> Result<()> {
615        let (catalog_from_db, schema) = parse_optional_catalog_and_schema_from_db_string(database);
616        let catalog = if let Some(catalog) = &catalog_from_db {
617            catalog.clone()
618        } else {
619            self.session.catalog()
620        };
621
622        if !self
623            .query_handler
624            .is_valid_schema(&catalog, &schema)
625            .await?
626        {
627            return w
628                .error(
629                    ErrorKind::ER_WRONG_DB_NAME,
630                    format!("Unknown database '{}'", database).as_bytes(),
631                )
632                .await
633                .map_err(|e| e.into());
634        }
635
636        let user_info = &self.session.user_info();
637
638        if let Some(schema_validator) = &self.user_provider
639            && let Err(e) = schema_validator
640                .authorize(&catalog, &schema, user_info)
641                .await
642        {
643            METRIC_AUTH_FAILURE
644                .with_label_values(&[e.status_code().as_ref()])
645                .inc();
646            return w
647                .error(
648                    ErrorKind::ER_DBACCESS_DENIED_ERROR,
649                    e.output_msg().as_bytes(),
650                )
651                .await
652                .map_err(|e| e.into());
653        }
654
655        if catalog_from_db.is_some() {
656            self.session.set_catalog(catalog)
657        }
658        self.session.set_schema(schema);
659
660        w.ok().await.map_err(|e| e.into())
661    }
662}
663
664fn convert_param_value_to_string(param: &ParamValue) -> String {
665    match param.value.into_inner() {
666        ValueInner::Int(u) => u.to_string(),
667        ValueInner::UInt(u) => u.to_string(),
668        ValueInner::Double(u) => u.to_string(),
669        ValueInner::NULL => "NULL".to_string(),
670        // MySQL prepared fallback emits SQL text. Delegate bytes/string literal
671        // escaping to mysql_common. `false` means normal MySQL backslash escapes;
672        // if NO_BACKSLASH_ESCAPES is supported in this path later, wire the
673        // session SQL mode here.
674        ValueInner::Bytes(b) => MysqlValue::Bytes(b.to_vec()).as_sql(false),
675        ValueInner::Date(_) => format!("'{}'", NaiveDate::from(param.value)),
676        ValueInner::Datetime(_) => format!("'{}'", NaiveDateTime::from(param.value)),
677        ValueInner::Time(_) => format_duration(Duration::from(param.value)),
678    }
679}
680
681fn replace_params(params: Vec<String>, stmt: Statement, mut query: String) -> Result<String> {
682    let spans = helper::placeholder_spans(stmt);
683    ensure!(
684        spans.len() == params.len(),
685        error::InternalSnafu {
686            err_msg: format!(
687                "Prepared statement expected {} parameters but got {}",
688                spans.len(),
689                params.len()
690            )
691        }
692    );
693
694    let mut replacements = Vec::with_capacity(spans.len());
695    for span in spans {
696        let start = location_to_byte_offset(&query, span.start_line, span.start_column)
697            .ok_or_else(|| {
698                error::InternalSnafu {
699                    err_msg: format!(
700                        "Invalid placeholder start span: line {}, column {}",
701                        span.start_line, span.start_column
702                    ),
703                }
704                .build()
705            })?;
706        let end =
707            location_to_byte_offset(&query, span.end_line, span.end_column).ok_or_else(|| {
708                error::InternalSnafu {
709                    err_msg: format!(
710                        "Invalid placeholder end span: line {}, column {}",
711                        span.end_line, span.end_column
712                    ),
713                }
714                .build()
715            })?;
716        let param = span
717            .index
718            .checked_sub(1)
719            .and_then(|idx| params.get(idx))
720            .ok_or_else(|| {
721                error::InternalSnafu {
722                    err_msg: format!("Missing prepared statement parameter {}", span.index),
723                }
724                .build()
725            })?;
726
727        ensure!(
728            start < end && end <= query.len(),
729            error::InternalSnafu {
730                err_msg: format!(
731                    "Invalid placeholder byte span: {}..{} for query length {}",
732                    start,
733                    end,
734                    query.len()
735                )
736            }
737        );
738        ensure!(
739            query.get(start..end) == Some("?"),
740            error::InternalSnafu {
741                err_msg: format!(
742                    "Prepared statement placeholder span maps to {:?} instead of '?'",
743                    query.get(start..end)
744                )
745            }
746        );
747
748        replacements.push((start, end, param.clone()));
749    }
750
751    replacements.sort_unstable_by_key(|(start, _, _)| *start);
752    for windows in replacements.windows(2) {
753        ensure!(
754            windows[0].1 <= windows[1].0,
755            error::InternalSnafu {
756                err_msg: "Overlapping placeholder spans in prepared statement".to_string()
757            }
758        );
759    }
760
761    // All spans are computed against the original query. Apply replacements
762    // from right to left so changing one parameter's string length never shifts
763    // the byte offsets of placeholders that have not been replaced yet.
764    for (start, end, param) in replacements.into_iter().rev() {
765        query.replace_range(start..end, &param);
766    }
767
768    Ok(query)
769}
770
/// Converts a 1-based sqlparser line/column location into a byte offset of `query`.
///
/// Columns count Rust `char`s, not bytes, and a `'\n'` starts a new line at
/// column 1. A location just past the last character (the exclusive end of a
/// trailing placeholder, e.g. in `SELECT ?`) maps to `query.len()`. Returns
/// `None` for a zero line/column or any location outside the text.
fn location_to_byte_offset(query: &str, line: u64, column: u64) -> Option<usize> {
    if line == 0 || column == 0 {
        return None;
    }

    let target = (line, column);
    let mut cursor: (u64, u64) = (1, 1);
    // The trailing sentinel stands for the position just after the final
    // character, so an end-of-text location is found by the same scan.
    query
        .char_indices()
        .chain(std::iter::once((query.len(), '\n')))
        .find_map(|(offset, ch)| {
            let hit = cursor == target;
            if ch == '\n' {
                cursor = (cursor.0 + 1, 1);
            } else {
                cursor.1 += 1;
            }
            hit.then_some(offset)
        })
}
798
/// Formats a non-negative `Duration` as a quoted MySQL `TIME` literal, such as
/// `'01:02:03'` or `'01:02:03.500000'`.
///
/// Hours are not wrapped at 24, because `TIME` values may span more than a day.
/// Fractional seconds are kept at microsecond resolution (the finest `TIME`
/// precision) and omitted entirely when zero. Previously they were silently
/// dropped.
fn format_duration(duration: Duration) -> String {
    let total_secs = duration.as_secs();
    let hours = total_secs / 3600;
    let minutes = (total_secs / 60) % 60;
    let seconds = total_secs % 60;
    let micros = duration.subsec_micros();
    if micros == 0 {
        format!("'{hours:02}:{minutes:02}:{seconds:02}'")
    } else {
        format!("'{hours:02}:{minutes:02}:{seconds:02}.{micros:06}'")
    }
}
805
806fn replace_params_with_values(
807    plan: &LogicalPlan,
808    param_types: HashMap<String, Option<ConcreteDataType>>,
809    params: &[ParamValue],
810) -> Result<LogicalPlan> {
811    debug_assert_eq!(param_types.len(), params.len());
812
813    debug!(
814        "replace_params_with_values(param_types: {:#?}, params: {:#?}, plan: {:#?})",
815        param_types,
816        params
817            .iter()
818            .map(|x| format!("({:?}, {:?})", x.value, x.coltype))
819            .join(", "),
820        plan
821    );
822
823    let mut values = Vec::with_capacity(params.len());
824
825    for (i, param) in params.iter().enumerate() {
826        if let Some(Some(t)) = param_types.get(&format_placeholder(i + 1)) {
827            let value = helper::convert_value(param, t)?;
828
829            values.push(value.into());
830        }
831    }
832
833    plan.clone()
834        .replace_params_with_values(&ParamValues::List(values.clone()))
835        .context(DataFrameSnafu)
836}
837
838fn replace_params_with_exprs(
839    plan: &LogicalPlan,
840    param_types: HashMap<String, Option<ConcreteDataType>>,
841    params: &[sql::ast::Expr],
842) -> Result<LogicalPlan> {
843    debug_assert_eq!(param_types.len(), params.len());
844
845    debug!(
846        "replace_params_with_exprs(param_types: {:#?}, params: {:#?}, plan: {:#?})",
847        param_types,
848        params.iter().map(|x| format!("({:?})", x)).join(", "),
849        plan
850    );
851
852    let mut values = Vec::with_capacity(params.len());
853
854    for (i, param) in params.iter().enumerate() {
855        if let Some(Some(t)) = param_types.get(&format_placeholder(i + 1)) {
856            let value = helper::convert_expr_to_scalar_value(param, t)?;
857
858            values.push(value.into());
859        }
860    }
861
862    plan.clone()
863        .replace_params_with_values(&ParamValues::List(values.clone()))
864        .context(DataFrameSnafu)
865}
866
867async fn validate_query(query: &str) -> Result<Statement> {
868    let statement =
869        ParserContext::create_with_dialect(query, &MySqlDialect {}, ParseOptions::default());
870    let mut statement = statement.map_err(|e| {
871        InvalidPrepareStatementSnafu {
872            err_msg: e.output_msg(),
873        }
874        .build()
875    })?;
876
877    ensure!(
878        statement.len() == 1,
879        InvalidPrepareStatementSnafu {
880            err_msg: "prepare statement only support single statement".to_string(),
881        }
882    );
883
884    let statement = statement.remove(0);
885
886    Ok(statement)
887}
888
889fn dummy_params(index: usize) -> Result<Vec<Column>> {
890    let mut params = Vec::with_capacity(index - 1);
891
892    for _ in 1..index {
893        params.push(create_mysql_column(&ConcreteDataType::null_datatype(), "")?);
894    }
895
896    Ok(params)
897}
898
899/// Parameters that the client must provide when executing the prepared statement.
900fn prepared_params(
901    param_types: &HashMap<String, Option<ConcreteDataType>>,
902    param_num: usize,
903) -> Result<Vec<Column>> {
904    let mut params = Vec::with_capacity(param_num - 1);
905
906    // Placeholder index starts from 1
907    for i in 1..param_num {
908        let column = if let Some(Some(t)) = param_types.get(&format_placeholder(i)) {
909            create_mysql_column(t, "")?
910        } else {
911            create_mysql_column(&ConcreteDataType::null_datatype(), "")?
912        };
913        params.push(column);
914    }
915
916    Ok(params)
917}
918
919fn all_params_have_types(
920    param_types: &HashMap<String, Option<ConcreteDataType>>,
921    param_num: usize,
922) -> bool {
923    param_types.len() == param_num - 1
924        && (1..param_num).all(|i| matches!(param_types.get(&format_placeholder(i)), Some(Some(_))))
925}
926
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use async_trait::async_trait;
    use common_query::Output;
    use datafusion_expr::LogicalPlan;
    use query::parser::PromQuery;
    use query::query_engine::DescribeResult;
    use session::context::QueryContext;
    use sql::statements::statement::Statement;

    use super::*;
    use crate::error::Result;
    use crate::query_handler::sql::SqlQueryHandler;

    /// Query handler stub. Every query method panics via `unimplemented!()`; only
    /// `is_valid_schema` is implemented, and it always returns `Ok(true)`.
    struct DummyQueryHandler;

    #[async_trait]
    impl SqlQueryHandler for DummyQueryHandler {
        async fn do_query(&self, _: &str, _: QueryContextRef) -> Vec<Result<Output>> {
            unimplemented!()
        }

        async fn do_promql_query(&self, _: &PromQuery, _: QueryContextRef) -> Vec<Result<Output>> {
            unimplemented!()
        }

        async fn do_exec_plan(
            &self,
            _: LogicalPlan,
            _: Option<Statement>,
            _: QueryContextRef,
        ) -> Result<Output> {
            unimplemented!()
        }

        async fn do_describe(
            &self,
            _: Statement,
            _: QueryContextRef,
        ) -> Result<Option<DescribeResult>> {
            unimplemented!()
        }

        async fn is_valid_schema(&self, _: &str, _: &str) -> Result<bool> {
            Ok(true)
        }
    }

    /// Builds a shim backed by [`DummyQueryHandler`] with no user provider.
    fn create_shim() -> MysqlInstanceShim {
        MysqlInstanceShim::create(
            Arc::new(DummyQueryHandler),
            None,
            "127.0.0.1:3306".parse().unwrap(),
            1,
            1024,
        )
    }

    /// Parses `query` (which must hold exactly one statement) and rewrites its `?`
    /// placeholders.
    fn statement_with_transformed_placeholders(query: &str) -> Statement {
        let mut statements =
            ParserContext::create_with_dialect(query, &MySqlDialect {}, ParseOptions::default())
                .unwrap();
        assert_eq!(statements.len(), 1);
        transform_placeholders_with_count(statements.remove(0)).0
    }

    #[test]
    fn test_prepared_params_keep_unknown_type_placeholders() {
        let mut param_types = HashMap::new();
        param_types.insert(format_placeholder(1), None);
        param_types.insert(
            format_placeholder(2),
            Some(ConcreteDataType::int32_datatype()),
        );

        let params = prepared_params(&param_types, 3).unwrap();
        assert_eq!(params.len(), 2);
        assert!(!all_params_have_types(&param_types, 3));
    }

    #[test]
    fn test_replace_params_by_placeholder_span() {
        // (query, parameter literals, expected query after substitution)
        let cases: &[(&str, &[&str], &str)] = &[
            (
                "SELECT ?, ?",
                &["'$2 should stay'", "'value'"],
                "SELECT '$2 should stay', 'value'",
            ),
            (
                "SELECT ?, ?, ?",
                &[
                    "'much longer than a placeholder'",
                    "0",
                    "'also much longer than a placeholder'",
                ],
                "SELECT 'much longer than a placeholder', 0, 'also much longer than a placeholder'",
            ),
            (
                "SELECT '$1', \"$2\", `$3`, ?, ?",
                &["'1'", "'2'"],
                "SELECT '$1', \"$2\", `$3`, '1', '2'",
            ),
            (
                "SELECT /* ? */ ? -- ?\n, ?",
                &["'first'", "'second'"],
                "SELECT /* ? */ 'first' -- ?\n, 'second'",
            ),
            ("SELECT '中文', ?", &["'value'"], "SELECT '中文', 'value'"),
            (
                "SELECT '中文',\n  ?",
                &["'value'"],
                "SELECT '中文',\n  'value'",
            ),
            ("SELECT 'x'\r\n, ?", &["'crlf'"], "SELECT 'x'\r\n, 'crlf'"),
            ("SELECT\t?", &["NULL"], "SELECT\tNULL"),
            (
                "SELECT CAST(? AS INT64), ? + (SELECT ?)",
                &["1", "2", "3"],
                "SELECT CAST(1 AS INT64), 2 + (SELECT 3)",
            ),
            ("SET time_zone = ?", &["'UTC'"], "SET time_zone = 'UTC'"),
        ];

        for &(query, params, expected) in cases {
            let stmt = statement_with_transformed_placeholders(query);
            let params: Vec<String> = params.iter().map(|p| p.to_string()).collect();
            let actual = replace_params(params, stmt, query.to_string()).unwrap();
            assert_eq!(expected, actual, "query: {query:?}");
        }
    }

    #[tokio::test]
    async fn test_prepare_federated_query() {
        let mut shim = create_shim();
        let query_ctx = QueryContext::arc();
        let stmt_key = "test_federated".to_string();

        let (params, columns) = shim
            .do_prepare("SELECT @@version_comment", query_ctx, stmt_key.clone())
            .await
            .unwrap();

        assert!(params.is_empty());
        assert!(columns.is_empty());

        let plan = shim.plan(&stmt_key).unwrap();
        assert!(matches!(plan, SqlPlan::Shortcut(q) if q == "SELECT @@version_comment"));
    }

    #[tokio::test]
    async fn test_execute_federated_shortcut() {
        let mut shim = create_shim();
        let query_ctx = QueryContext::arc();
        let stmt_key = "test_federated_exec".to_string();

        shim.do_prepare(
            "SELECT @@version_comment",
            query_ctx.clone(),
            stmt_key.clone(),
        )
        .await
        .unwrap();

        let outputs = shim
            .do_execute(query_ctx, stmt_key, Params::CliParams(vec![]))
            .await
            .unwrap();

        assert_eq!(outputs.len(), 1);
        let output = outputs.into_iter().next().unwrap().unwrap();
        let pretty = output.data.pretty_print().await;
        assert!(pretty.contains("GreptimeDB"));
    }

    /// `SET NAMES` is prepared through the shortcut path rather than as a logical plan.
    /// (This test was previously named `..._not_shortcut`, which contradicted its assertion.)
    #[tokio::test]
    async fn test_prepare_set_names_is_shortcut() {
        let mut shim = create_shim();
        let query_ctx = QueryContext::arc();
        let stmt_key = "test_non_federated".to_string();

        // `unwrap` (rather than `assert!(is_ok())`) surfaces the error if preparing fails.
        shim.do_prepare("SET NAMES utf8", query_ctx, stmt_key.clone())
            .await
            .unwrap();

        let plan = shim.plan(&stmt_key).unwrap();
        assert!(matches!(plan, SqlPlan::Shortcut(_)));
    }

    #[tokio::test]
    async fn test_execute_set_shortcut() {
        let mut shim = create_shim();
        let query_ctx = QueryContext::arc();
        let stmt_key = "test_set_shortcut".to_string();

        shim.do_prepare("SET NAMES utf8", query_ctx.clone(), stmt_key.clone())
            .await
            .unwrap();

        let outputs = shim
            .do_execute(query_ctx, stmt_key, Params::CliParams(vec![]))
            .await
            .unwrap();

        assert_eq!(outputs.len(), 1);
        let output = outputs.into_iter().next().unwrap().unwrap();
        match output.data {
            common_query::OutputData::RecordBatches(batches) => {
                let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
                assert_eq!(total_rows, 0);
            }
            other => panic!("Expected RecordBatches, got {:?}", other),
        }
    }
}