feat: handle Ctrl-C command in MySQL client (#6320)

* feat/answer-ctrl-c-in-mysql:
 ## Implement Connection ID-based Query Killing

 ### Key Changes:
 - **Connection ID Management:**
   - Added `connection_id` to `Session` and `QueryContext` in `src/session/src/lib.rs` and `src/session/src/context.rs`.
   - Updated `MysqlInstanceShim` and `MysqlServer` to handle `connection_id` in `src/servers/src/mysql/handler.rs` and `src/servers/src/mysql/server.rs`.

 - **KILL Statement Enhancements:**
   - Introduced `Kill` enum to handle both `ProcessId` and `ConnectionId` in `src/sql/src/statements/kill.rs`.
   - Updated `ParserContext` to parse `KILL QUERY <connection_id>` in `src/sql/src/parser.rs`.
   - Modified `StatementExecutor` to support killing queries by `connection_id` in `src/operator/src/statement/kill.rs`.

 - **Process Management:**
   - Refactored `ProcessManager` to include `connection_id` in `src/catalog/src/process_manager.rs`.
   - Added `kill_local_process` method for local query termination.

 - **Testing:**
   - Added tests for `KILL` statement parsing and execution in `src/sql/src/parser.rs`.

 ### Affected Files:
 - `Cargo.lock`, `Cargo.toml`
 - `src/catalog/src/process_manager.rs`
 - `src/frontend/src/instance.rs`
 - `src/frontend/src/stream_wrapper.rs`
 - `src/operator/src/statement.rs`
 - `src/operator/src/statement/kill.rs`
 - `src/servers/src/mysql/federated.rs`
 - `src/servers/src/mysql/handler.rs`
 - `src/servers/src/mysql/server.rs`
 - `src/servers/src/postgres.rs`
 - `src/session/src/context.rs`
 - `src/session/src/lib.rs`
 - `src/sql/src/parser.rs`
 - `src/sql/src/statements.rs`
 - `src/sql/src/statements/kill.rs`
 - `src/sql/src/statements/statement.rs`

Signed-off-by: Lei, HUANG <mrsatangel@gmail.com>

 Conflicts:
	Cargo.lock
	Cargo.toml

Signed-off-by: Lei, HUANG <mrsatangel@gmail.com>

* feat/answer-ctrl-c-in-mysql:
 ### Enhance Process Management and Execution

 - **`process_manager.rs`**: Added a new method `find_processes_by_connection_id` to filter processes by connection ID, improving process management capabilities.
 - **`kill.rs`**: Refactored the process killing logic to utilize the new `find_processes_by_connection_id` method, streamlining the execution flow and reducing redundant checks.

Signed-off-by: Lei, HUANG <mrsatangel@gmail.com>

* feat/answer-ctrl-c-in-mysql:
 ## Commit Message

 ### Update Process ID Type and Refactor Code

 - **Change Process ID Type**: Updated the process ID type from `u64` to `u32` across multiple files to optimize memory usage. Affected files include `process_manager.rs`, `lib.rs`, `database.rs`, `instance.rs`, `server.rs`, `stream_wrapper.rs`, `kill.rs`, `federated.rs`, `handler.rs`, `server.rs`,
 `postgres.rs`, `mysql_server_test.rs`, `context.rs`, `lib.rs`, and `test_util.rs`.

 - **Remove Connection ID**: Removed the `connection_id` field and related logic from `process_manager.rs`, `lib.rs`, `instance.rs`, `server.rs`, `stream_wrapper.rs`, `kill.rs`, `federated.rs`, `handler.rs`, `server.rs`, `postgres.rs`, `mysql_server_test.rs`, `context.rs`, `lib.rs`, and `test_util.rs` to
 simplify the codebase.

 - **Refactor Process Management**: Refactored process management logic to improve clarity and maintainability in `process_manager.rs`, `kill.rs`, and `handler.rs`.

 - **Enhance MySQL Server Handling**: Improved MySQL server handling by integrating process management in `server.rs` and `mysql_server_test.rs`.

Signed-off-by: Lei, HUANG <mrsatangel@gmail.com>

* feat/answer-ctrl-c-in-mysql:
 ### Add Process Manager to Postgres Server

 - **`src/frontend/src/server.rs`**: Updated server initialization to include `process_manager`.
 - **`src/servers/src/postgres.rs`**: Modified `MakePostgresServerHandler` to accept `process_id` for session creation.
 - **`src/servers/src/postgres/server.rs`**: Integrated `process_manager` into `PostgresServer` for generating `process_id` during connection handling.
 - **`src/servers/tests/postgres/mod.rs`** and **`tests-integration/src/test_util.rs`**: Adjusted test server setup to accommodate optional `process_manager`.

Signed-off-by: Lei, HUANG <mrsatangel@gmail.com>

* feat/answer-ctrl-c-in-mysql:
 Update `greptime-proto` Dependency

 - Updated the `greptime-proto` dependency to a new revision in both `Cargo.lock` and `Cargo.toml`.
   - `Cargo.lock`: Changed source revision from `d75a56e05a87594fe31ad5c48525e9b2124149ba` to `fdcbe5f1c7c467634c90a1fd1a00a784b92a4e80`.
   - `Cargo.toml`: Updated the `greptime-proto` git revision to match the new commit.

Signed-off-by: Lei, HUANG <mrsatangel@gmail.com>

---------

Signed-off-by: Lei, HUANG <mrsatangel@gmail.com>
This commit is contained in:
Lei, HUANG
2025-06-17 14:36:23 +08:00
committed by GitHub
parent 3e3a12385c
commit ecbbd2fbdb
21 changed files with 402 additions and 68 deletions

2
Cargo.lock generated
View File

@@ -5143,7 +5143,7 @@ dependencies = [
[[package]]
name = "greptime-proto"
version = "0.1.0"
source = "git+https://github.com/GreptimeTeam/greptime-proto.git?rev=17971523673f4fbc982510d3c9d6647ff642e16f#17971523673f4fbc982510d3c9d6647ff642e16f"
source = "git+https://github.com/GreptimeTeam/greptime-proto.git?rev=fdcbe5f1c7c467634c90a1fd1a00a784b92a4e80#fdcbe5f1c7c467634c90a1fd1a00a784b92a4e80"
dependencies = [
"prost 0.13.5",
"serde",

View File

@@ -134,7 +134,7 @@ etcd-client = "0.14"
fst = "0.4.7"
futures = "0.3"
futures-util = "0.3"
greptime-proto = { git = "https://github.com/GreptimeTeam/greptime-proto.git", rev = "17971523673f4fbc982510d3c9d6647ff642e16f" }
greptime-proto = { git = "https://github.com/GreptimeTeam/greptime-proto.git", rev = "fdcbe5f1c7c467634c90a1fd1a00a784b92a4e80" }
hex = "0.4"
http = "1"
humantime = "2.1"

View File

@@ -15,7 +15,7 @@
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::fmt::{Debug, Formatter};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::{Arc, RwLock};
use api::v1::frontend::{KillProcessRequest, ListProcessRequest, ProcessInfo};
@@ -29,6 +29,7 @@ use snafu::{ensure, OptionExt, ResultExt};
use crate::error;
use crate::metrics::{PROCESS_KILL_COUNT, PROCESS_LIST_COUNT};
pub type ProcessId = u32;
pub type ProcessManagerRef = Arc<ProcessManager>;
/// Query process manager.
@@ -36,9 +37,9 @@ pub struct ProcessManager {
/// Local frontend server address,
server_addr: String,
/// Next process id for local queries.
next_id: AtomicU64,
next_id: AtomicU32,
/// Running process per catalog.
catalogs: RwLock<HashMap<String, HashMap<u64, CancellableProcess>>>,
catalogs: RwLock<HashMap<String, HashMap<ProcessId, CancellableProcess>>>,
/// Frontend selector to locate frontend nodes.
frontend_selector: Option<MetaClientSelector>,
}
@@ -65,9 +66,9 @@ impl ProcessManager {
schemas: Vec<String>,
query: String,
client: String,
id: Option<u64>,
query_id: Option<ProcessId>,
) -> Ticket {
let id = id.unwrap_or_else(|| self.next_id.fetch_add(1, Ordering::Relaxed));
let id = query_id.unwrap_or_else(|| self.next_id.fetch_add(1, Ordering::Relaxed));
let process = ProcessInfo {
id,
catalog: catalog.clone(),
@@ -96,12 +97,12 @@ impl ProcessManager {
}
/// Generates the next process id.
pub fn next_id(&self) -> u64 {
pub fn next_id(&self) -> u32 {
self.next_id.fetch_add(1, Ordering::Relaxed)
}
/// De-register a query from process list.
pub fn deregister_query(&self, catalog: String, id: u64) {
pub fn deregister_query(&self, catalog: String, id: ProcessId) {
if let Entry::Occupied(mut o) = self.catalogs.write().unwrap().entry(catalog) {
let process = o.get_mut().remove(&id);
debug!("Deregister process: {:?}", process);
@@ -159,26 +160,10 @@ impl ProcessManager {
&self,
server_addr: String,
catalog: String,
id: u64,
id: ProcessId,
) -> error::Result<bool> {
if server_addr == self.server_addr {
if let Some(catalogs) = self.catalogs.write().unwrap().get_mut(&catalog) {
if let Some(process) = catalogs.remove(&id) {
process.handle.cancel();
info!(
"Killed process, catalog: {}, id: {:?}",
process.process.catalog, process.process.id
);
PROCESS_KILL_COUNT.with_label_values(&[&catalog]).inc();
Ok(true)
} else {
debug!("Failed to kill process, id not found: {}", id);
Ok(false)
}
} else {
debug!("Failed to kill process, catalog not found: {}", catalog);
Ok(false)
}
self.kill_local_process(catalog, id).await
} else {
let mut nodes = self
.frontend_selector
@@ -204,12 +189,33 @@ impl ProcessManager {
Ok(true)
}
}
/// Kills local query with provided catalog and id.
pub async fn kill_local_process(&self, catalog: String, id: ProcessId) -> error::Result<bool> {
if let Some(catalogs) = self.catalogs.write().unwrap().get_mut(&catalog) {
if let Some(process) = catalogs.remove(&id) {
process.handle.cancel();
info!(
"Killed process, catalog: {}, id: {:?}",
process.process.catalog, process.process.id
);
PROCESS_KILL_COUNT.with_label_values(&[&catalog]).inc();
Ok(true)
} else {
debug!("Failed to kill process, id not found: {}", id);
Ok(false)
}
} else {
debug!("Failed to kill process, catalog not found: {}", catalog);
Ok(false)
}
}
}
pub struct Ticket {
pub(crate) catalog: String,
pub(crate) manager: ProcessManagerRef,
pub(crate) id: u64,
pub(crate) id: ProcessId,
pub cancellation_handle: Arc<CancellationHandle>,
}
@@ -323,7 +329,7 @@ mod tests {
assert_eq!(running_processes.len(), 2);
// Verify both processes are present
let ids: Vec<u64> = running_processes.iter().map(|p| p.id).collect();
let ids: Vec<u32> = running_processes.iter().map(|p| p.id).collect();
assert!(ids.contains(&ticket1.id));
assert!(ids.contains(&ticket2.id));
}

View File

@@ -23,7 +23,7 @@ pub mod selector;
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct DisplayProcessId {
pub server_addr: String,
pub id: u64,
pub id: u32,
}
impl Display for DisplayProcessId {
@@ -44,7 +44,7 @@ impl TryFrom<&str> for DisplayProcessId {
let id = split
.next()
.context(error::ParseProcessIdSnafu { s: value })?;
let id = u64::from_str(id)
let id = u32::from_str(id)
.ok()
.context(error::ParseProcessIdSnafu { s: value })?;
Ok(DisplayProcessId { server_addr, id })

View File

@@ -18,7 +18,7 @@ use std::sync::Arc;
use common_query::error::Result;
use common_query::prelude::{Signature, Volatility};
use datatypes::prelude::{ConcreteDataType, ScalarVector};
use datatypes::vectors::{StringVector, UInt64Vector, VectorRef};
use datatypes::vectors::{StringVector, UInt32Vector, VectorRef};
use derive_more::Display;
use crate::function::{Function, FunctionContext};
@@ -144,7 +144,7 @@ impl Function for PgBackendPidFunction {
fn eval(&self, func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result<VectorRef> {
let pid = func_ctx.query_ctx.process_id();
Ok(Arc::new(UInt64Vector::from_slice([pid])) as _)
Ok(Arc::new(UInt32Vector::from_slice([pid])) as _)
}
}
@@ -164,7 +164,7 @@ impl Function for ConnectionIdFunction {
fn eval(&self, func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result<VectorRef> {
let pid = func_ctx.query_ctx.process_id();
Ok(Arc::new(UInt64Vector::from_slice([pid])) as _)
Ok(Arc::new(UInt32Vector::from_slice([pid])) as _)
}
}

View File

@@ -235,6 +235,7 @@ where
opts.keep_alive.as_secs(),
opts.reject_no_database.unwrap_or(false),
)),
Some(instance.process_manager().clone()),
);
handlers.insert((mysql_server, mysql_addr));
}
@@ -257,6 +258,7 @@ where
opts.keep_alive.as_secs(),
common_runtime::global_runtime(),
user_provider.clone(),
Some(self.instance.process_manager().clone()),
)) as Box<dyn Server>;
handlers.insert((pg_server, pg_addr));

View File

@@ -369,7 +369,7 @@ impl StatementExecutor {
Statement::ShowSearchPath(_) => self.show_search_path(query_ctx).await,
Statement::Use(db) => self.use_database(db, query_ctx).await,
Statement::Admin(admin) => self.execute_admin_command(admin, query_ctx).await,
Statement::Kill(id) => self.execute_kill(query_ctx, id).await,
Statement::Kill(kill) => self.execute_kill(query_ctx, kill).await,
}
}

View File

@@ -12,11 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use catalog::process_manager::ProcessManagerRef;
use common_frontend::DisplayProcessId;
use common_query::Output;
use common_telemetry::error;
use session::context::QueryContextRef;
use snafu::ResultExt;
use sql::statements::kill::Kill;
use crate::error;
use crate::statement::StatementExecutor;
@@ -25,22 +27,51 @@ impl StatementExecutor {
pub async fn execute_kill(
&self,
query_ctx: QueryContextRef,
process_id: String,
) -> crate::error::Result<Output> {
kill: Kill,
) -> error::Result<Output> {
let Some(process_manager) = self.process_manager.as_ref() else {
error!("Process manager is not initialized");
return error::ProcessManagerMissingSnafu.fail();
};
let succ = match kill {
Kill::ProcessId(process_id) => {
self.kill_process_id(process_manager, query_ctx, process_id)
.await?
}
Kill::ConnectionId(conn_id) => {
self.kill_connection_id(process_manager, query_ctx, conn_id)
.await?
}
};
Ok(Output::new_with_affected_rows(if succ { 1 } else { 0 }))
}
/// Handles `KILL <PROCESS_ID>` statements.
async fn kill_process_id(
&self,
pm: &ProcessManagerRef,
query_ctx: QueryContextRef,
process_id: String,
) -> error::Result<bool> {
let display_id = DisplayProcessId::try_from(process_id.as_str())
.map_err(|_| error::InvalidProcessIdSnafu { id: process_id }.build())?;
let current_user_catalog = query_ctx.current_catalog().to_string();
process_manager
.kill_process(display_id.server_addr, current_user_catalog, display_id.id)
pm.kill_process(display_id.server_addr, current_user_catalog, display_id.id)
.await
.context(error::CatalogSnafu)?;
.context(error::CatalogSnafu)
}
Ok(Output::new_with_affected_rows(0))
/// Handles MySQL `KILL QUERY <CONNECTION_ID>` statements.
pub async fn kill_connection_id(
&self,
pm: &ProcessManagerRef,
query_ctx: QueryContextRef,
connection_id: u32,
) -> error::Result<bool> {
pm.kill_local_process(query_ctx.current_catalog().to_string(), connection_id)
.await
.context(error::CatalogSnafu)
}
}

View File

@@ -82,6 +82,7 @@ pub struct MysqlInstanceShim {
user_provider: Option<UserProviderRef>,
prepared_stmts: Arc<RwLock<HashMap<String, SqlPlan>>>,
prepared_stmts_counter: AtomicU32,
process_id: u32,
}
impl MysqlInstanceShim {
@@ -89,6 +90,7 @@ impl MysqlInstanceShim {
query_handler: ServerSqlQueryHandlerRef,
user_provider: Option<UserProviderRef>,
client_addr: SocketAddr,
process_id: u32,
) -> MysqlInstanceShim {
// init a random salt
let mut bs = vec![0u8; 20];
@@ -110,12 +112,12 @@ impl MysqlInstanceShim {
Some(client_addr),
Channel::Mysql,
Default::default(),
// TODO(sunng87): generate process id properly
0,
process_id,
)),
user_provider,
prepared_stmts: Default::default(),
prepared_stmts_counter: AtomicU32::new(1),
process_id,
}
}
@@ -341,6 +343,10 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
std::env::var("GREPTIMEDB_MYSQL_SERVER_VERSION").unwrap_or_else(|_| "8.4.2".to_string())
}
fn connect_id(&self) -> u32 {
self.process_id
}
fn default_auth_plugin(&self) -> &str {
self.auth_plugin()
}

View File

@@ -18,6 +18,7 @@ use std::sync::Arc;
use async_trait::async_trait;
use auth::UserProviderRef;
use catalog::process_manager::ProcessManagerRef;
use common_runtime::runtime::RuntimeTrait;
use common_runtime::Runtime;
use common_telemetry::{debug, warn};
@@ -112,6 +113,7 @@ pub struct MysqlServer {
spawn_ref: Arc<MysqlSpawnRef>,
spawn_config: Arc<MysqlSpawnConfig>,
bind_addr: Option<SocketAddr>,
process_manager: Option<ProcessManagerRef>,
}
impl MysqlServer {
@@ -119,16 +121,23 @@ impl MysqlServer {
io_runtime: Runtime,
spawn_ref: Arc<MysqlSpawnRef>,
spawn_config: Arc<MysqlSpawnConfig>,
process_manager: Option<ProcessManagerRef>,
) -> Box<dyn Server> {
Box::new(MysqlServer {
base_server: BaseTcpServer::create_server("MySQL", io_runtime),
spawn_ref,
spawn_config,
bind_addr: None,
process_manager,
})
}
fn accept(&self, io_runtime: Runtime, stream: AbortableStream) -> impl Future<Output = ()> {
fn accept(
&self,
io_runtime: Runtime,
stream: AbortableStream,
process_manager: Option<ProcessManagerRef>,
) -> impl Future<Output = ()> {
let spawn_ref = self.spawn_ref.clone();
let spawn_config = self.spawn_config.clone();
@@ -136,7 +145,7 @@ impl MysqlServer {
let spawn_ref = spawn_ref.clone();
let spawn_config = spawn_config.clone();
let io_runtime = io_runtime.clone();
let process_id = process_manager.as_ref().map(|p| p.next_id()).unwrap_or(8);
async move {
match tcp_stream {
Err(e) => warn!(e; "Broken pipe"), // IoError doesn't impl ErrorExt.
@@ -146,7 +155,7 @@ impl MysqlServer {
}
io_runtime.spawn(async move {
if let Err(error) =
Self::handle(io_stream, spawn_ref, spawn_config).await
Self::handle(io_stream, spawn_ref, spawn_config, process_id).await
{
warn!(error; "Unexpected error when handling TcpStream");
};
@@ -161,10 +170,11 @@ impl MysqlServer {
stream: TcpStream,
spawn_ref: Arc<MysqlSpawnRef>,
spawn_config: Arc<MysqlSpawnConfig>,
process_id: u32,
) -> Result<()> {
debug!("MySQL connection coming from: {}", stream.peer_addr()?);
crate::metrics::METRIC_MYSQL_CONNECTIONS.inc();
if let Err(e) = Self::do_handle(stream, spawn_ref, spawn_config).await {
if let Err(e) = Self::do_handle(stream, spawn_ref, spawn_config, process_id).await {
if let Error::InternalIo { error } = &e
&& error.kind() == std::io::ErrorKind::ConnectionAborted
{
@@ -184,11 +194,13 @@ impl MysqlServer {
stream: TcpStream,
spawn_ref: Arc<MysqlSpawnRef>,
spawn_config: Arc<MysqlSpawnConfig>,
process_id: u32,
) -> Result<()> {
let mut shim = MysqlInstanceShim::create(
spawn_ref.query_handler(),
spawn_ref.user_provider(),
stream.peer_addr()?,
process_id,
);
let (mut r, w) = stream.into_split();
let mut w = BufWriter::with_capacity(DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE, w);
@@ -230,7 +242,11 @@ impl Server for MysqlServer {
.await?;
let io_runtime = self.base_server.io_runtime();
let join_handle = common_runtime::spawn_global(self.accept(io_runtime, stream));
let join_handle = common_runtime::spawn_global(self.accept(
io_runtime,
stream,
self.process_manager.clone(),
));
self.base_server.start_with(join_handle).await?;
self.bind_addr = Some(addr);

View File

@@ -119,9 +119,13 @@ impl PgWireServerHandlers for PostgresServerHandler {
}
impl MakePostgresServerHandler {
fn make(&self, addr: Option<SocketAddr>) -> PostgresServerHandler {
// TODO(sunng87): generate pid from process manager
let session = Arc::new(Session::new(addr, Channel::Postgres, Default::default(), 0));
fn make(&self, addr: Option<SocketAddr>, process_id: u32) -> PostgresServerHandler {
let session = Arc::new(Session::new(
addr,
Channel::Postgres,
Default::default(),
process_id,
));
let handler = PostgresServerHandlerInner {
query_handler: self.query_handler.clone(),
login_verifier: PgLoginVerifier::new(self.user_provider.clone()),

View File

@@ -18,6 +18,7 @@ use std::sync::Arc;
use ::auth::UserProviderRef;
use async_trait::async_trait;
use catalog::process_manager::ProcessManagerRef;
use common_runtime::runtime::RuntimeTrait;
use common_runtime::Runtime;
use common_telemetry::{debug, warn};
@@ -37,6 +38,7 @@ pub struct PostgresServer {
tls_server_config: Arc<ReloadableTlsServerConfig>,
keep_alive_secs: u64,
bind_addr: Option<SocketAddr>,
process_manager: Option<ProcessManagerRef>,
}
impl PostgresServer {
@@ -48,6 +50,7 @@ impl PostgresServer {
keep_alive_secs: u64,
io_runtime: Runtime,
user_provider: Option<UserProviderRef>,
process_manager: Option<ProcessManagerRef>,
) -> PostgresServer {
let make_handler = Arc::new(
MakePostgresServerHandlerBuilder::default()
@@ -63,6 +66,7 @@ impl PostgresServer {
tls_server_config,
keep_alive_secs,
bind_addr: None,
process_manager,
}
}
@@ -73,12 +77,12 @@ impl PostgresServer {
) -> impl Future<Output = ()> {
let handler_maker = self.make_handler.clone();
let tls_server_config = self.tls_server_config.clone();
let process_manager = self.process_manager.clone();
accepting_stream.for_each(move |tcp_stream| {
let io_runtime = io_runtime.clone();
let tls_acceptor = tls_server_config.get_server_config().map(TlsAcceptor::from);
let handler_maker = handler_maker.clone();
let process_id = process_manager.as_ref().map(|p| p.next_id()).unwrap_or(0);
async move {
match tcp_stream {
@@ -97,7 +101,7 @@ impl PostgresServer {
let _handle = io_runtime.spawn(async move {
crate::metrics::METRIC_POSTGRES_CONNECTIONS.inc();
let pg_handler = Arc::new(handler_maker.make(addr));
let pg_handler = Arc::new(handler_maker.make(addr, process_id));
let r =
process_socket(io_stream, tls_acceptor.clone(), pg_handler).await;
crate::metrics::METRIC_POSTGRES_CONNECTIONS.dec();

View File

@@ -73,6 +73,7 @@ fn create_mysql_server(table: TableRef, opts: MysqlOpts<'_>) -> Result<Box<dyn S
0,
opts.reject_no_database,
)),
None,
))
}

View File

@@ -71,6 +71,7 @@ fn create_postgres_server(
0,
io_runtime,
user_provider,
None,
)))
}

View File

@@ -66,7 +66,7 @@ pub struct QueryContext {
channel: Channel,
/// Process id for managing on-going queries
#[builder(default)]
process_id: u64,
process_id: u32,
/// Connection information
#[builder(default)]
conn_info: ConnInfo,
@@ -439,7 +439,7 @@ impl QueryContext {
.copied()
}
pub fn process_id(&self) -> u64 {
pub fn process_id(&self) -> u32 {
self.process_id
}

View File

@@ -41,8 +41,8 @@ pub struct Session {
mutable_inner: Arc<RwLock<MutableInner>>,
conn_info: ConnInfo,
configuration_variables: Arc<ConfigurationVariables>,
// the connection id to use when killing the query
process_id: u64,
// the process id to use when killing the query
process_id: u32,
}
pub type SessionRef = Arc<Session>;
@@ -77,7 +77,7 @@ impl Session {
addr: Option<SocketAddr>,
channel: Channel,
configuration_variables: ConfigurationVariables,
process_id: u64,
process_id: u32,
) -> Self {
Session {
catalog: RwLock::new(DEFAULT_CATALOG_NAME.into()),
@@ -152,7 +152,7 @@ impl Session {
build_db_string(&self.catalog(), &self.schema())
}
pub fn process_id(&self) -> u64 {
pub fn process_id(&self) -> u32 {
self.process_id
}
}

View File

@@ -12,8 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::str::FromStr;
use snafu::ResultExt;
use sqlparser::ast::{Ident, Query};
use sqlparser::ast::{Ident, Query, Value};
use sqlparser::dialect::Dialect;
use sqlparser::keywords::Keyword;
use sqlparser::parser::{Parser, ParserError, ParserOptions};
@@ -22,6 +24,7 @@ use sqlparser::tokenizer::{Token, TokenWithSpan};
use crate::ast::{Expr, ObjectName};
use crate::error::{self, Result, SyntaxSnafu};
use crate::parsers::tql_parser;
use crate::statements::kill::Kill;
use crate::statements::statement::Statement;
use crate::statements::transform_statements;
@@ -190,14 +193,43 @@ impl ParserContext<'_> {
Keyword::KILL => {
let _ = self.parser.next_token();
let process_id_ident =
self.parser.parse_literal_string().with_context(|_| {
error::UnexpectedSnafu {
expected: "process id string literal",
actual: self.peek_token_as_string(),
let kill = if self.parser.parse_keyword(Keyword::QUERY) {
// MySQL KILL QUERY <connection id> statements
let connection_id_exp =
self.parser.parse_number_value().with_context(|_| {
error::UnexpectedSnafu {
expected: "MySQL numeric connection id",
actual: self.peek_token_as_string(),
}
})?;
let Value::Number(s, _) = connection_id_exp else {
return error::UnexpectedTokenSnafu {
expected: "MySQL numeric connection id",
actual: connection_id_exp.to_string(),
}
.fail();
};
let connection_id = u32::from_str(&s).map_err(|_| {
error::UnexpectedTokenSnafu {
expected: "MySQL numeric connection id",
actual: s,
}
.build()
})?;
Ok(Statement::Kill(process_id_ident))
Kill::ConnectionId(connection_id)
} else {
let process_id_ident =
self.parser.parse_literal_string().with_context(|_| {
error::UnexpectedSnafu {
expected: "process id string literal",
actual: self.peek_token_as_string(),
}
})?;
Kill::ProcessId(process_id_ident)
};
Ok(Statement::Kill(kill))
}
_ => self.unsupported(self.peek_token_as_string()),
@@ -440,4 +472,191 @@ mod tests {
let stmt_name = ParserContext::parse_mysql_deallocate_stmt(sql, &MySqlDialect {}).unwrap();
assert_eq!(stmt_name, "stmt2");
}
#[test]
pub fn test_parse_kill_query_statement() {
use crate::statements::kill::Kill;
// Test MySQL-style KILL QUERY with connection ID
let sql = "KILL QUERY 123";
let statements =
ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
.unwrap();
assert_eq!(statements.len(), 1);
match &statements[0] {
Statement::Kill(Kill::ConnectionId(connection_id)) => {
assert_eq!(*connection_id, 123);
}
_ => panic!("Expected Kill::ConnectionId statement"),
}
// Test with larger connection ID
let sql = "KILL QUERY 999999";
let statements =
ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
.unwrap();
assert_eq!(statements.len(), 1);
match &statements[0] {
Statement::Kill(Kill::ConnectionId(connection_id)) => {
assert_eq!(*connection_id, 999999);
}
_ => panic!("Expected Kill::ConnectionId statement"),
}
}
#[test]
pub fn test_parse_kill_process_statement() {
use crate::statements::kill::Kill;
// Test KILL with process ID string
let sql = "KILL 'process-123'";
let statements =
ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
.unwrap();
assert_eq!(statements.len(), 1);
match &statements[0] {
Statement::Kill(Kill::ProcessId(process_id)) => {
assert_eq!(process_id, "process-123");
}
_ => panic!("Expected Kill::ProcessId statement"),
}
// Test with double quotes
let sql = "KILL \"process-456\"";
let statements =
ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
.unwrap();
assert_eq!(statements.len(), 1);
match &statements[0] {
Statement::Kill(Kill::ProcessId(process_id)) => {
assert_eq!(process_id, "process-456");
}
_ => panic!("Expected Kill::ProcessId statement"),
}
// Test with UUID-like process ID
let sql = "KILL 'f47ac10b-58cc-4372-a567-0e02b2c3d479'";
let statements =
ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
.unwrap();
assert_eq!(statements.len(), 1);
match &statements[0] {
Statement::Kill(Kill::ProcessId(process_id)) => {
assert_eq!(process_id, "f47ac10b-58cc-4372-a567-0e02b2c3d479");
}
_ => panic!("Expected Kill::ProcessId statement"),
}
}
#[test]
pub fn test_parse_kill_statement_errors() {
// Test KILL QUERY without connection ID
let sql = "KILL QUERY";
let result =
ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default());
assert!(result.is_err());
// Test KILL QUERY with non-numeric connection ID
let sql = "KILL QUERY 'not-a-number'";
let result =
ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default());
assert!(result.is_err());
// Test KILL without any argument
let sql = "KILL";
let result =
ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default());
assert!(result.is_err());
// Test KILL QUERY with connection ID that's too large for u32
let sql = "KILL QUERY 4294967296"; // u32::MAX + 1
let result =
ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default());
assert!(result.is_err());
}
#[test]
pub fn test_parse_kill_statement_edge_cases() {
use crate::statements::kill::Kill;
// Test KILL QUERY with zero connection ID
let sql = "KILL QUERY 0";
let statements =
ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
.unwrap();
assert_eq!(statements.len(), 1);
match &statements[0] {
Statement::Kill(Kill::ConnectionId(connection_id)) => {
assert_eq!(*connection_id, 0);
}
_ => panic!("Expected Kill::ConnectionId statement"),
}
// Test KILL QUERY with maximum u32 value
let sql = "KILL QUERY 4294967295"; // u32::MAX
let statements =
ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
.unwrap();
assert_eq!(statements.len(), 1);
match &statements[0] {
Statement::Kill(Kill::ConnectionId(connection_id)) => {
assert_eq!(*connection_id, 4294967295);
}
_ => panic!("Expected Kill::ConnectionId statement"),
}
// Test KILL with empty string process ID
let sql = "KILL ''";
let statements =
ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
.unwrap();
assert_eq!(statements.len(), 1);
match &statements[0] {
Statement::Kill(Kill::ProcessId(process_id)) => {
assert_eq!(process_id, "");
}
_ => panic!("Expected Kill::ProcessId statement"),
}
}
#[test]
pub fn test_parse_kill_statement_case_insensitive() {
use crate::statements::kill::Kill;
// Test lowercase
let sql = "kill query 123";
let statements =
ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
.unwrap();
assert_eq!(statements.len(), 1);
match &statements[0] {
Statement::Kill(Kill::ConnectionId(connection_id)) => {
assert_eq!(*connection_id, 123);
}
_ => panic!("Expected Kill::ConnectionId statement"),
}
// Test mixed case
let sql = "Kill Query 456";
let statements =
ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
.unwrap();
assert_eq!(statements.len(), 1);
match &statements[0] {
Statement::Kill(Kill::ConnectionId(connection_id)) => {
assert_eq!(*connection_id, 456);
}
_ => panic!("Expected Kill::ConnectionId statement"),
}
}
}

View File

@@ -22,6 +22,7 @@ pub mod describe;
pub mod drop;
pub mod explain;
pub mod insert;
pub mod kill;
mod option_map;
pub mod query;
pub mod set_variables;

View File

@@ -0,0 +1,40 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::fmt::{Display, Formatter};
use serde::Serialize;
use sqlparser_derive::{Visit, VisitMut};
/// Arguments of `KILL` statements.
#[derive(Debug, Clone, Eq, PartialEq, Visit, VisitMut, Serialize)]
pub enum Kill {
/// Kill a remote process id.
ProcessId(String),
/// Kill MySQL connection id.
ConnectionId(u32),
}
impl Display for Kill {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
Kill::ProcessId(id) => {
write!(f, "KILL {}", id)
}
Kill::ConnectionId(id) => {
write!(f, "KILL QUERY {}", id)
}
}
}
}

View File

@@ -32,6 +32,7 @@ use crate::statements::describe::DescribeTable;
use crate::statements::drop::{DropDatabase, DropFlow, DropTable, DropView};
use crate::statements::explain::Explain;
use crate::statements::insert::Insert;
use crate::statements::kill::Kill;
use crate::statements::query::Query;
use crate::statements::set_variables::SetVariables;
use crate::statements::show::{
@@ -139,7 +140,7 @@ pub enum Statement {
// CLOSE
CloseCursor(CloseCursor),
// KILL <process>
Kill(String),
Kill(Kill),
}
impl Display for Statement {

View File

@@ -648,6 +648,7 @@ pub async fn setup_mysql_server_with_user_provider(
0,
opts.reject_no_database.unwrap_or(false),
)),
None,
);
mysql_server
@@ -697,6 +698,7 @@ pub async fn setup_pg_server_with_user_provider(
0,
runtime,
user_provider,
None,
));
pg_server