feat: use regex to filter out not supported MySQL stmt (#396)

* feat: use regex to filter out not supported MySQL stmt

* fix: resolve PR comments

Co-authored-by: luofucong <luofucong@greptime.com>
This commit is contained in:
LFC
2022-11-08 11:09:46 +08:00
committed by GitHub
parent 89a3b39728
commit f34a99ff5a
8 changed files with 454 additions and 20 deletions

View File

@@ -9,7 +9,7 @@ repos:
rev: e6a795bc6b2c0958f9ef52af4863bbd7cc17238f
hooks:
- id: cargo-sort
args: ["--workspace"]
args: ["--workspace", "--print"]
- repo: https://github.com/doublify/pre-commit-rust
rev: v1.0

6
Cargo.lock generated
View File

@@ -3350,9 +3350,9 @@ dependencies = [
[[package]]
name = "once_cell"
version = "1.15.0"
version = "1.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e82dad04139b71a90c080c8463fe0dc7902db5192d939bd0950f074d014339e1"
checksum = "86f0b0d4bf799edbc74508c1e8bf170ff5f41238e5f8225603ca7caaae2b7860"
[[package]]
name = "oorandom"
@@ -5030,12 +5030,14 @@ dependencies = [
"metrics",
"mysql_async",
"num_cpus",
"once_cell",
"openmetrics-parser",
"opensrv-mysql",
"pgwire",
"prost 0.11.0",
"query",
"rand 0.8.5",
"regex",
"schemars",
"script",
"serde",

View File

@@ -4,9 +4,12 @@ mod recordbatch;
pub mod util;
use std::pin::Pin;
use std::sync::Arc;
use datafusion::arrow_print;
pub use datafusion::physical_plan::SendableRecordBatchStream as DfSendableRecordBatchStream;
use datatypes::schema::SchemaRef;
use datatypes::prelude::VectorRef;
use datatypes::schema::{Schema, SchemaRef};
use error::Result;
use futures::task::{Context, Poll};
use futures::Stream;
@@ -54,6 +57,35 @@ pub struct RecordBatches {
}
impl RecordBatches {
/// Builds a [`RecordBatches`] that holds exactly one batch made of `columns`.
///
/// # Errors
/// Returns the error produced by [`RecordBatch::new`] when a batch cannot be built from
/// `schema` and `columns`.
pub fn try_from_columns<I: IntoIterator<Item = VectorRef>>(
    schema: SchemaRef,
    columns: I,
) -> Result<Self> {
    // The batch keeps its own handle to the schema; the original handle is stored on `Self`.
    let batch = RecordBatch::new(schema.clone(), columns)?;
    Ok(Self {
        schema,
        batches: vec![batch],
    })
}
/// Creates a [`RecordBatches`] with a column-less schema and no batches.
#[inline]
pub fn empty() -> Self {
    let schema = Arc::new(Schema::new(Vec::new()));
    Self {
        schema,
        batches: Vec::new(),
    }
}
/// Returns an iterator over references to the contained batches, in stored order.
pub fn iter(&self) -> impl Iterator<Item = &RecordBatch> {
self.batches.iter()
}
/// Formats all contained batches with `arrow_print::write` and returns the resulting text.
pub fn pretty_print(&self) -> String {
    // `arrow_print::write` is given a slice of the inner `df_recordbatch` values,
    // so clone them out of the wrappers first.
    let df_batches: Vec<_> = self
        .iter()
        .map(|batch| batch.df_recordbatch.clone())
        .collect();
    arrow_print::write(&df_batches)
}
pub fn try_new(schema: SchemaRef, batches: Vec<RecordBatch>) -> Result<Self> {
for batch in batches.iter() {
ensure!(
@@ -124,7 +156,26 @@ mod tests {
use super::*;
#[test]
fn test_recordbatches() {
fn test_recordbatches_try_from_columns() {
    let column_schema = ColumnSchema::new("a", ConcreteDataType::int32_datatype(), false);
    let schema = Arc::new(Schema::new(vec![column_schema]));

    // A string column is rejected for a schema whose only column is int32.
    let mismatched: VectorRef = Arc::new(StringVector::from(vec!["hello", "world"]));
    let result = RecordBatches::try_from_columns(schema.clone(), vec![mismatched]);
    assert!(result.is_err());

    // A matching int32 column produces the same single batch as `RecordBatch::new`.
    let ints: VectorRef = Arc::new(Int32Vector::from_slice(&[1, 2]));
    let expected = vec![RecordBatch::new(schema.clone(), vec![ints.clone()]).unwrap()];
    let batches = RecordBatches::try_from_columns(schema, vec![ints]).unwrap();
    assert_eq!(batches.take(), expected);
}
#[test]
fn test_recordbatches_try_new() {
let column_a = ColumnSchema::new("a", ConcreteDataType::int32_datatype(), false);
let column_b = ColumnSchema::new("b", ConcreteDataType::string_datatype(), false);
let column_c = ColumnSchema::new("c", ConcreteDataType::boolean_datatype(), false);
@@ -150,6 +201,15 @@ mod tests {
);
let batches = RecordBatches::try_new(schema1.clone(), vec![batch1.clone()]).unwrap();
let expected = "\
+---+-------+
| a | b |
+---+-------+
| 1 | hello |
| 2 | world |
+---+-------+";
assert_eq!(batches.pretty_print(), expected);
assert_eq!(schema1, batches.schema());
assert_eq!(vec![batch1], batches.take());
}

View File

@@ -285,8 +285,6 @@ mod tests {
admin_expr, admin_result, column, column::SemanticType, object_expr, object_result,
select_expr, Column, ExprHeader, MutateResult, SelectExpr,
};
use datafusion::arrow_print;
use datafusion_common::record_batch::RecordBatch as DfRecordBatch;
use datatypes::schema::ColumnDefaultConstraint;
use datatypes::value::Value;
@@ -327,12 +325,7 @@ mod tests {
let output = SqlQueryHandler::do_query(&*instance, sql).await.unwrap();
match output {
Output::RecordBatches(recordbatches) => {
let recordbatches = recordbatches
.take()
.into_iter()
.map(|r| r.df_recordbatch)
.collect::<Vec<DfRecordBatch>>();
let pretty_print = arrow_print::write(&recordbatches);
let pretty_print = recordbatches.pretty_print();
let pretty_print = pretty_print.lines().collect::<Vec<&str>>();
let expected = vec![
"+----------------+---------------------+-----+--------+-----------+",
@@ -352,12 +345,7 @@ mod tests {
let output = SqlQueryHandler::do_query(&*instance, sql).await.unwrap();
match output {
Output::RecordBatches(recordbatches) => {
let recordbatches = recordbatches
.take()
.into_iter()
.map(|r| r.df_recordbatch)
.collect::<Vec<DfRecordBatch>>();
let pretty_print = arrow_print::write(&recordbatches);
let pretty_print = recordbatches.pretty_print();
let pretty_print = pretty_print.lines().collect::<Vec<&str>>();
let expected = vec![
"+----------------+---------------------+-----+--------+-----------+",

View File

@@ -26,10 +26,12 @@ hyper = { version = "0.14", features = ["full"] }
influxdb_line_protocol = { git = "https://github.com/evenyag/influxdb_iox", branch = "feat/line-protocol" }
metrics = "0.20"
num_cpus = "1.13"
once_cell = "1.16"
openmetrics-parser = "0.4"
opensrv-mysql = "0.1"
pgwire = { version = "0.4" }
prost = "0.11"
regex = "1.6"
schemars = "0.8"
serde = "1.0"
serde_json = "1.0"

View File

@@ -0,0 +1,374 @@
//! Uses regular expressions to intercept statements emitted by MySQL federated components that are not supported, answering them with canned results.
//! Inspired by Databend's "[mysql_federated.rs](https://github.com/datafuselabs/databend/blob/ac706bf65845e6895141c96c0a10bad6fdc2d367/src/query/service/src/servers/mysql/mysql_federated.rs)".
use std::collections::HashMap;
use std::sync::Arc;
use common_query::Output;
use common_recordbatch::RecordBatches;
use datatypes::prelude::ConcreteDataType;
use datatypes::schema::{ColumnSchema, Schema};
use datatypes::vectors::StringVector;
use once_cell::sync::Lazy;
use regex::bytes::RegexSet;
use regex::Regex;
// Version string returned to clients that issue `SELECT VERSION()`.
// TODO(LFC): Include GreptimeDB's version and git commit tag etc.
const MYSQL_VERSION: &str = "8.0.26";
// All patterns below are case-insensitive (`(?i)`), anchored at the start of the query (`^`),
// and compiled lazily on first use.

// Queries starting with "SELECT @@", i.e. selecting system variables.
static SELECT_VAR_PATTERN: Lazy<Regex> = Lazy::new(|| Regex::new("(?i)^(SELECT @@(.*))").unwrap());
// Queries starting with a "/* mysql-connector-java..." comment.
static MYSQL_CONN_JAVA_PATTERN: Lazy<Regex> =
Lazy::new(|| Regex::new("(?i)^(/\\* mysql-connector-java(.*))").unwrap());
// "SHOW VARIABLES LIKE 'lower_case_table_names'".
static SHOW_LOWER_CASE_PATTERN: Lazy<Regex> =
Lazy::new(|| Regex::new("(?i)^(SHOW VARIABLES LIKE 'lower_case_table_names'(.*))").unwrap());
// "SHOW COLLATION WHERE ...".
static SHOW_COLLATION_PATTERN: Lazy<Regex> =
Lazy::new(|| Regex::new("(?i)^(show collation where(.*))").unwrap());
// Any "SHOW VARIABLES ..." statement. This also matches the more specific
// "SHOW VARIABLES LIKE ..." patterns, so those have to be tested before this one.
static SHOW_VARIABLES_PATTERN: Lazy<Regex> =
Lazy::new(|| Regex::new("(?i)^(SHOW VARIABLES(.*))").unwrap());
// "SELECT VERSION()", allowing whitespace between the parentheses.
static SELECT_VERSION_PATTERN: Lazy<Regex> =
Lazy::new(|| Regex::new(r"(?i)^(SELECT VERSION\(\s*\))").unwrap());
// SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP());
static SELECT_TIME_DIFF_FUNC_PATTERN: Lazy<Regex> =
Lazy::new(|| Regex::new("(?i)^(SELECT TIMEDIFF\\(NOW\\(\\), UTC_TIMESTAMP\\(\\)\\))").unwrap());
// "SHOW VARIABLES LIKE 'sql_mode'".
// sqlalchemy < 1.4.30
static SHOW_SQL_MODE_PATTERN: Lazy<Regex> =
Lazy::new(|| Regex::new("(?i)^(SHOW VARIABLES LIKE 'sql_mode'(.*))").unwrap());
// Statements emitted by MySQL clients and tools that are matched here and answered with an
// empty result (see `check_others`).
//
// Every pattern is case-insensitive and anchored at the start of the query. The set is a
// byte-oriented `RegexSet`, so callers match against `query.as_bytes()`. Each pattern is
// listed once; a pattern shared by several tools is kept under the first tool that needs it.
static OTHER_NOT_SUPPORTED_STMT: Lazy<RegexSet> = Lazy::new(|| {
    RegexSet::new(&[
        // Txn.
        "(?i)^(ROLLBACK(.*))",
        "(?i)^(COMMIT(.*))",
        "(?i)^(START(.*))",
        // Set.
        "(?i)^(SET NAMES(.*))",
        "(?i)^(SET character_set_results(.*))",
        "(?i)^(SET net_write_timeout(.*))",
        "(?i)^(SET FOREIGN_KEY_CHECKS(.*))",
        "(?i)^(SET AUTOCOMMIT(.*))",
        "(?i)^(SET SQL_LOG_BIN(.*))",
        "(?i)^(SET sql_mode(.*))",
        "(?i)^(SET SQL_SELECT_LIMIT(.*))",
        "(?i)^(SET @@(.*))",
        "(?i)^(SHOW COLLATION)",
        "(?i)^(SHOW CHARSET)",
        // mysqldump.
        "(?i)^(SET SESSION(.*))",
        "(?i)^(SET SQL_QUOTE_SHOW_CREATE(.*))",
        "(?i)^(LOCK TABLES(.*))",
        "(?i)^(UNLOCK TABLES(.*))",
        "(?i)^(SELECT LOGFILE_GROUP_NAME, FILE_NAME, TOTAL_EXTENTS, INITIAL_SIZE, ENGINE, EXTRA FROM INFORMATION_SCHEMA.FILES(.*))",
        // mydumper.
        "(?i)^(/\\*!80003 SET(.*) \\*/)$",
        "(?i)^(SHOW MASTER STATUS)",
        "(?i)^(SHOW ALL SLAVES STATUS)",
        "(?i)^(LOCK BINLOG FOR BACKUP)",
        "(?i)^(LOCK TABLES FOR BACKUP)",
        "(?i)^(UNLOCK BINLOG(.*))",
        // Also emitted by pt-toolkit and mysqldump 5.7.16.
        "(?i)^(/\\*!40101 SET(.*) \\*/)$",
        // DBeaver.
        "(?i)^(SHOW WARNINGS)",
        "(?i)^(/\\* ApplicationName=(.*)SHOW WARNINGS)",
        "(?i)^(/\\* ApplicationName=(.*)SHOW PLUGINS)",
        "(?i)^(/\\* ApplicationName=(.*)SHOW COLLATION)",
        "(?i)^(/\\* ApplicationName=(.*)SHOW CHARSET)",
        "(?i)^(/\\* ApplicationName=(.*)SHOW ENGINES)",
        "(?i)^(/\\* ApplicationName=(.*)SELECT @@(.*))",
        "(?i)^(/\\* ApplicationName=(.*)SHOW @@(.*))",
        "(?i)^(/\\* ApplicationName=(.*)SET net_write_timeout(.*))",
        "(?i)^(/\\* ApplicationName=(.*)SET SQL_SELECT_LIMIT(.*))",
        "(?i)^(/\\* ApplicationName=(.*)SHOW VARIABLES(.*))",
        // pt-toolkit: its `/*!40101 SET ... */` statement is covered by the mydumper group.
        // mysqldump 5.7.16 (`/*!40101 SET ... */` is covered by the mydumper group).
        "(?i)^(/\\*!40100 SET(.*) \\*/)$",
        "(?i)^(/\\*!40103 SET(.*) \\*/)$",
        "(?i)^(/\\*!40111 SET(.*) \\*/)$",
        "(?i)^(/\\*!40014 SET(.*) \\*/)$",
        "(?i)^(/\\*!40000 SET(.*) \\*/)$",
    ])
    // The patterns are compile-time constants, so failing to compile them is a programming error.
    .unwrap()
});
// Canned values for the system variables that clients commonly query with "SELECT @@...".
// Keys are lowercase variable names without the "@@" prefix.
static VAR_VALUES: Lazy<HashMap<&str, &str>> = Lazy::new(|| {
    const DEFAULTS: &[(&str, &str)] = &[
        ("tx_isolation", "REPEATABLE-READ"),
        ("session.tx_isolation", "REPEATABLE-READ"),
        ("transaction_isolation", "REPEATABLE-READ"),
        ("session.transaction_isolation", "REPEATABLE-READ"),
        ("session.transaction_read_only", "0"),
        ("time_zone", "UTC"),
        ("system_time_zone", "UTC"),
        ("max_allowed_packet", "134217728"),
        ("interactive_timeout", "31536000"),
        ("wait_timeout", "31536000"),
        ("net_write_timeout", "31536000"),
        ("version_comment", "Greptime"),
    ];
    DEFAULTS.iter().copied().collect()
});
// Builds the one-column, one-row result of a faked "SELECT <function>" query.
// Layout:
// | name  |
// | value |
fn select_function(name: &str, value: &str) -> RecordBatches {
    let field = ColumnSchema::new(name, ConcreteDataType::string_datatype(), true);
    let schema = Arc::new(Schema::new(vec![field]));
    let columns = vec![Arc::new(StringVector::from(vec![value])) as _];
    // The only column is declared as string and filled with a string vector, so the schema
    // and the data always agree and the unwrap cannot fail.
    RecordBatches::try_from_columns(schema, columns).unwrap()
}
// Builds the two-column, one-row result of a faked "SHOW VARIABLES" statement.
// Layout:
// | Variable_name | Value |
// | name          | value |
fn show_variables(name: &str, value: &str) -> RecordBatches {
    let name_field = ColumnSchema::new("Variable_name", ConcreteDataType::string_datatype(), true);
    let value_field = ColumnSchema::new("Value", ConcreteDataType::string_datatype(), true);
    let schema = Arc::new(Schema::new(vec![name_field, value_field]));

    let name_column = Arc::new(StringVector::from(vec![name])) as _;
    let value_column = Arc::new(StringVector::from(vec![value])) as _;
    // Both columns are declared as string and filled with string vectors, so the schema and
    // the data always agree and the unwrap cannot fail.
    RecordBatches::try_from_columns(schema, vec![name_column, value_column]).unwrap()
}
// Builds the one-row result of a query like "SELECT @@aa, @@bb as cc, ...".
//
// The query is lowercased first. Every "@@"-prefixed variable then becomes one string column
// whose value is looked up in `VAR_VALUES` ("0" when the variable is not listed there). The
// column is named after the alias when " as <alias>" is present, and "@@<variable>" otherwise.
// Returns `None` when the query contains no "@@", or when a variable is followed by more than
// one " as ".
fn select_variable(query: &str) -> Option<Output> {
    let query = query.to_lowercase();
    if !query.contains("@@") {
        return None;
    }

    let mut fields = vec![];
    let mut values = vec![];
    // The text before the first "@@" is the leading "select", so it is skipped.
    for segment in query.split("@@").skip(1) {
        let parts: Vec<&str> = segment
            .trim_matches(&[' ', ','][..])
            .split(" as ")
            // Only the first word on each side of " as " is kept.
            .map(|part| {
                part.trim_matches(' ')
                    .split_whitespace()
                    .next()
                    .unwrap_or("")
            })
            .collect();

        let (variable, column_name) = match parts.as_slice() {
            // "@@aa": the column is named "@@aa".
            [variable] => (*variable, format!("@@{}", variable)),
            // "@@bb as cc": the variable is "bb", the column is named "cc".
            [variable, alias] => (*variable, alias.to_string()),
            _ => return None,
        };

        let value = VAR_VALUES.get(variable).copied().unwrap_or("0");
        values.push(Arc::new(StringVector::from(vec![value])) as _);
        fields.push(ColumnSchema::new(
            column_name.as_str(),
            ConcreteDataType::string_datatype(),
            true,
        ));
    }

    let schema = Arc::new(Schema::new(fields));
    // Every field is declared as string and every column is a string vector, so the schema
    // and the data always agree and the unwrap cannot fail.
    let batches = RecordBatches::try_from_columns(schema, values).unwrap();
    Some(Output::RecordBatches(batches))
}
// Answers "SELECT @@variable ..." style queries with faked variable values.
//
// Only queries that start with "SELECT @@" or with a "/* mysql-connector-java" comment are
// handed to `select_variable`; everything else yields `None`.
fn check_select_variable(query: &str) -> Option<Output> {
    // A plain short-circuit `||` avoids allocating a temporary Vec on every query.
    if SELECT_VAR_PATTERN.is_match(query) || MYSQL_CONN_JAVA_PATTERN.is_match(query) {
        select_variable(query)
    } else {
        None
    }
}
// Answers "SHOW VARIABLES ..." and "SHOW COLLATION WHERE ..." statements with a faked
// single-row result, or returns `None` when the query is neither.
fn check_show_variables(query: &str) -> Option<Output> {
    // The specific "LIKE '...'" patterns are tested before the catch-all SHOW VARIABLES one.
    let (name, value) = if SHOW_SQL_MODE_PATTERN.is_match(query) {
        ("sql_mode", "ONLY_FULL_GROUP_BY STRICT_TRANS_TABLES NO_ZERO_IN_DATE NO_ZERO_DATE ERROR_FOR_DIVISION_BY_ZERO NO_ENGINE_SUBSTITUTION")
    } else if SHOW_LOWER_CASE_PATTERN.is_match(query) {
        ("lower_case_table_names", "0")
    } else if SHOW_COLLATION_PATTERN.is_match(query) || SHOW_VARIABLES_PATTERN.is_match(query) {
        ("", "")
    } else {
        return None;
    };
    Some(Output::RecordBatches(show_variables(name, value)))
}
// Final check of the federated query handling: SET and other statements listed in
// `OTHER_NOT_SUPPORTED_STMT` get an empty result, while "SELECT VERSION()" and
// "SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP())" get a faked one-cell result.
// Returns `None` for everything else.
fn check_others(query: &str) -> Option<Output> {
    if OTHER_NOT_SUPPORTED_STMT.is_match(query.as_bytes()) {
        return Some(Output::RecordBatches(RecordBatches::empty()));
    }

    if SELECT_VERSION_PATTERN.is_match(query) {
        let batches = select_function("version()", MYSQL_VERSION);
        return Some(Output::RecordBatches(batches));
    }

    if SELECT_TIME_DIFF_FUNC_PATTERN.is_match(query) {
        let batches = select_function("TIMEDIFF(NOW(), UTC_TIMESTAMP())", "00:00:00");
        return Some(Output::RecordBatches(batches));
    }

    None
}
// Checks whether the query is a federated or driver setup command and, if so, returns a
// faked result for it.
//
// The checks run in this order and the first one that produces a result wins:
// 1. "select @@variables" style queries,
// 2. "show variables like ..." style statements,
// 3. everything else that is known (SET statements, "select version()", ...).
pub fn check(query: &str) -> Option<Output> {
    check_select_variable(query)
        .or_else(|| check_show_variables(query))
        .or_else(|| check_others(query))
}
#[cfg(test)]
mod test {
use super::*;
// Exercises `check` with queries that must pass through untouched and with queries that
// must be answered with a faked result.
#[test]
fn test_check() {
// Ordinary queries are not intercepted.
let query = "select 1";
let result = check(query);
assert!(result.is_none());
// "versiona" is not the `version()` function call, so this is not intercepted either.
let query = "select versiona";
let output = check(query);
assert!(output.is_none());
// Asserts that `check` intercepts `query` and that the pretty-printed record batches
// equal `expected`, compared line by line.
fn test(query: &str, expected: Vec<&str>) {
let output = check(query);
match output.unwrap() {
Output::RecordBatches(r) => {
assert_eq!(r.pretty_print().lines().collect::<Vec<_>>(), expected)
}
_ => unreachable!(),
}
}
let query = "select version()";
let expected = vec![
"+-----------+",
"| version() |",
"+-----------+",
"| 8.0.26 |",
"+-----------+",
];
test(query, expected);
// The trailing "LIMIT 1" does not become part of the column name.
let query = "SELECT @@version_comment LIMIT 1";
let expected = vec![
"+-------------------+",
"| @@version_comment |",
"+-------------------+",
"| Greptime |",
"+-------------------+",
];
test(query, expected);
// variables
let query = "select @@tx_isolation, @@session.tx_isolation";
let expected = vec![
"+-----------------+------------------------+",
"| @@tx_isolation | @@session.tx_isolation |",
"+-----------------+------------------------+",
"| REPEATABLE-READ | REPEATABLE-READ |",
"+-----------------+------------------------+",
];
test(query, expected);
// complex variables
// Variables missing from `VAR_VALUES` are reported as "0". Note that the query's trailing
// ';' ends up in the last column name ("wait_timeout;").
let query = "/* mysql-connector-java-8.0.17 (Revision: 16a712ddb3f826a1933ab42b0039f7fb9eebc6ec) */SELECT @@session.auto_increment_increment AS auto_increment_increment, @@character_set_client AS character_set_client, @@character_set_connection AS character_set_connection, @@character_set_results AS character_set_results, @@character_set_server AS character_set_server, @@collation_server AS collation_server, @@collation_connection AS collation_connection, @@init_connect AS init_connect, @@interactive_timeout AS interactive_timeout, @@license AS license, @@lower_case_table_names AS lower_case_table_names, @@max_allowed_packet AS max_allowed_packet, @@net_write_timeout AS net_write_timeout, @@performance_schema AS performance_schema, @@sql_mode AS sql_mode, @@system_time_zone AS system_time_zone, @@time_zone AS time_zone, @@transaction_isolation AS transaction_isolation, @@wait_timeout AS wait_timeout;";
let expected = vec![
"+--------------------------+----------------------+--------------------------+-----------------------+----------------------+------------------+----------------------+--------------+---------------------+---------+------------------------+--------------------+-------------------+--------------------+----------+------------------+-----------+-----------------------+---------------+",
"| auto_increment_increment | character_set_client | character_set_connection | character_set_results | character_set_server | collation_server | collation_connection | init_connect | interactive_timeout | license | lower_case_table_names | max_allowed_packet | net_write_timeout | performance_schema | sql_mode | system_time_zone | time_zone | transaction_isolation | wait_timeout; |",
"+--------------------------+----------------------+--------------------------+-----------------------+----------------------+------------------+----------------------+--------------+---------------------+---------+------------------------+--------------------+-------------------+--------------------+----------+------------------+-----------+-----------------------+---------------+",
"| 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 31536000 | 0 | 0 | 134217728 | 31536000 | 0 | 0 | UTC | UTC | REPEATABLE-READ | 31536000 |",
"+--------------------------+----------------------+--------------------------+-----------------------+----------------------+------------------+----------------------+--------------+---------------------+---------+------------------------+--------------------+-------------------+--------------------+----------+------------------+-----------+-----------------------+---------------+",
];
test(query, expected);
// A bare "show variables" is answered with a single row of empty strings.
let query = "show variables";
let expected = vec![
"+---------------+-------+",
"| Variable_name | Value |",
"+---------------+-------+",
"| | |",
"+---------------+-------+",
];
test(query, expected);
let query = "show variables like 'lower_case_table_names'";
let expected = vec![
"+------------------------+-------+",
"| Variable_name | Value |",
"+------------------------+-------+",
"| lower_case_table_names | 0 |",
"+------------------------+-------+",
];
test(query, expected);
// "show collation" (without "where") is matched by `OTHER_NOT_SUPPORTED_STMT`, which
// yields `RecordBatches::empty()`.
let query = "show collation";
let expected = vec!["++", "++"]; // empty
test(query, expected);
let query = "SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP())";
let expected = vec![
"+----------------------------------+",
"| TIMEDIFF(NOW(), UTC_TIMESTAMP()) |",
"+----------------------------------+",
"| 00:00:00 |",
"+----------------------------------+",
];
test(query, expected);
}
}

View File

@@ -63,7 +63,14 @@ impl<W: io::Write + Send + Sync> AsyncMysqlShim<W> for MysqlInstanceShim {
query: &'a str,
writer: QueryResultWriter<'a, W>,
) -> Result<()> {
let output = self.query_handler.do_query(query).await;
// TODO(LFC): Find a better way:
// `check` uses regex to filter out unsupported statements emitted by MySQL's federated
// components, this is quick and dirty, there must be a better way to do it.
let output = if let Some(output) = crate::mysql::federated::check(query) {
Ok(output)
} else {
self.query_handler.do_query(query).await
};
let mut writer = MysqlResultWriter::new(writer);
writer.write(output).await

View File

@@ -1,3 +1,4 @@
mod federated;
pub mod handler;
pub mod server;
pub mod writer;