From ecbbd2fbdb7181433e205e005cdb7d8953262344 Mon Sep 17 00:00:00 2001 From: "Lei, HUANG" <6406592+v0y4g3r@users.noreply.github.com> Date: Tue, 17 Jun 2025 14:36:23 +0800 Subject: [PATCH] feat: handle `Ctrl-C` command in MySQL client (#6320) * feat/answer-ctrl-c-in-mysql: ## Implement Connection ID-based Query Killing ### Key Changes: - **Connection ID Management:** - Added `connection_id` to `Session` and `QueryContext` in `src/session/src/lib.rs` and `src/session/src/context.rs`. - Updated `MysqlInstanceShim` and `MysqlServer` to handle `connection_id` in `src/servers/src/mysql/handler.rs` and `src/servers/src/mysql/server.rs`. - **KILL Statement Enhancements:** - Introduced `Kill` enum to handle both `ProcessId` and `ConnectionId` in `src/sql/src/statements/kill.rs`. - Updated `ParserContext` to parse `KILL QUERY ` in `src/sql/src/parser.rs`. - Modified `StatementExecutor` to support killing queries by `connection_id` in `src/operator/src/statement/kill.rs`. - **Process Management:** - Refactored `ProcessManager` to include `connection_id` in `src/catalog/src/process_manager.rs`. - Added `kill_local_process` method for local query termination. - **Testing:** - Added tests for `KILL` statement parsing and execution in `src/sql/src/parser.rs`. ### Affected Files: - `Cargo.lock`, `Cargo.toml` - `src/catalog/src/process_manager.rs` - `src/frontend/src/instance.rs` - `src/frontend/src/stream_wrapper.rs` - `src/operator/src/statement.rs` - `src/operator/src/statement/kill.rs` - `src/servers/src/mysql/federated.rs` - `src/servers/src/mysql/handler.rs` - `src/servers/src/mysql/server.rs` - `src/servers/src/postgres.rs` - `src/session/src/context.rs` - `src/session/src/lib.rs` - `src/sql/src/parser.rs` - `src/sql/src/statements.rs` - `src/sql/src/statements/kill.rs` - `src/sql/src/statements/statement.rs` Signed-off-by: Lei, HUANG  Conflicts:  Cargo.lock  Cargo.toml Signed-off-by: Lei, HUANG * feat/answer-ctrl-c-in-mysql: ### Enhance Process Management and Execution - **`process_manager.rs`**: Added a new method `find_processes_by_connection_id` to filter processes by connection ID, improving process management capabilities. - **`kill.rs`**: Refactored the process killing logic to utilize the new `find_processes_by_connection_id` method, streamlining the execution flow and reducing redundant checks. Signed-off-by: Lei, HUANG * feat/answer-ctrl-c-in-mysql: ## Commit Message ### Update Process ID Type and Refactor Code - **Change Process ID Type**: Updated the process ID type from `u64` to `u32` across multiple files to optimize memory usage. Affected files include `process_manager.rs`, `lib.rs`, `database.rs`, `instance.rs`, `server.rs`, `stream_wrapper.rs`, `kill.rs`, `federated.rs`, `handler.rs`, `server.rs`, `postgres.rs`, `mysql_server_test.rs`, `context.rs`, `lib.rs`, and `test_util.rs`. - **Remove Connection ID**: Removed the `connection_id` field and related logic from `process_manager.rs`, `lib.rs`, `instance.rs`, `server.rs`, `stream_wrapper.rs`, `kill.rs`, `federated.rs`, `handler.rs`, `server.rs`, `postgres.rs`, `mysql_server_test.rs`, `context.rs`, `lib.rs`, and `test_util.rs` to simplify the codebase. - **Refactor Process Management**: Refactored process management logic to improve clarity and maintainability in `process_manager.rs`, `kill.rs`, and `handler.rs`. - **Enhance MySQL Server Handling**: Improved MySQL server handling by integrating process management in `server.rs` and `mysql_server_test.rs`. Signed-off-by: Lei, HUANG * feat/answer-ctrl-c-in-mysql: ### Add Process Manager to Postgres Server - **`src/frontend/src/server.rs`**: Updated server initialization to include `process_manager`. - **`src/servers/src/postgres.rs`**: Modified `MakePostgresServerHandler` to accept `process_id` for session creation. - **`src/servers/src/postgres/server.rs`**: Integrated `process_manager` into `PostgresServer` for generating `process_id` during connection handling. - **`src/servers/tests/postgres/mod.rs`** and **`tests-integration/src/test_util.rs`**: Adjusted test server setup to accommodate optional `process_manager`. Signed-off-by: Lei, HUANG * feat/answer-ctrl-c-in-mysql: Update `greptime-proto` Dependency - Updated the `greptime-proto` dependency to a new revision in both `Cargo.lock` and `Cargo.toml`. - `Cargo.lock`: Changed source revision from `d75a56e05a87594fe31ad5c48525e9b2124149ba` to `fdcbe5f1c7c467634c90a1fd1a00a784b92a4e80`. - `Cargo.toml`: Updated the `greptime-proto` git revision to match the new commit. Signed-off-by: Lei, HUANG --------- Signed-off-by: Lei, HUANG --- Cargo.lock | 2 +- Cargo.toml | 2 +- src/catalog/src/process_manager.rs | 60 ++--- src/common/frontend/src/lib.rs | 4 +- src/common/function/src/system/database.rs | 6 +- src/frontend/src/server.rs | 2 + src/operator/src/statement.rs | 2 +- src/operator/src/statement/kill.rs | 43 +++- src/servers/src/mysql/handler.rs | 10 +- src/servers/src/mysql/server.rs | 26 ++- src/servers/src/postgres.rs | 10 +- src/servers/src/postgres/server.rs | 10 +- src/servers/tests/mysql/mysql_server_test.rs | 1 + src/servers/tests/postgres/mod.rs | 1 + src/session/src/context.rs | 4 +- src/session/src/lib.rs | 8 +- src/sql/src/parser.rs | 233 ++++++++++++++++++- src/sql/src/statements.rs | 1 + src/sql/src/statements/kill.rs | 40 ++++ src/sql/src/statements/statement.rs | 3 +- tests-integration/src/test_util.rs | 2 + 21 files changed, 402 insertions(+), 68 deletions(-) create mode 100644 src/sql/src/statements/kill.rs diff --git a/Cargo.lock b/Cargo.lock index 5999d195ab..54327fbbb3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5143,7 +5143,7 @@ dependencies = [ [[package]] name = "greptime-proto" version = "0.1.0" -source = "git+https://github.com/GreptimeTeam/greptime-proto.git?rev=17971523673f4fbc982510d3c9d6647ff642e16f#17971523673f4fbc982510d3c9d6647ff642e16f" +source = "git+https://github.com/GreptimeTeam/greptime-proto.git?rev=fdcbe5f1c7c467634c90a1fd1a00a784b92a4e80#fdcbe5f1c7c467634c90a1fd1a00a784b92a4e80" dependencies = [ "prost 0.13.5", "serde", diff --git a/Cargo.toml b/Cargo.toml index 702d9fdd0b..636162fd2d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -134,7 +134,7 @@ etcd-client = "0.14" fst = "0.4.7" futures = "0.3" futures-util = "0.3" -greptime-proto = { git = "https://github.com/GreptimeTeam/greptime-proto.git", rev = "17971523673f4fbc982510d3c9d6647ff642e16f" } +greptime-proto = { git = "https://github.com/GreptimeTeam/greptime-proto.git", rev = "fdcbe5f1c7c467634c90a1fd1a00a784b92a4e80" } hex = "0.4" http = "1" humantime = "2.1" diff --git a/src/catalog/src/process_manager.rs b/src/catalog/src/process_manager.rs index 3f7950da7d..ff2db26f46 100644 --- a/src/catalog/src/process_manager.rs +++ b/src/catalog/src/process_manager.rs @@ -15,7 +15,7 @@ use std::collections::hash_map::Entry; use std::collections::HashMap; use std::fmt::{Debug, Formatter}; -use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::{Arc, RwLock}; use api::v1::frontend::{KillProcessRequest, ListProcessRequest, ProcessInfo}; @@ -29,6 +29,7 @@ use snafu::{ensure, OptionExt, ResultExt}; use crate::error; use crate::metrics::{PROCESS_KILL_COUNT, PROCESS_LIST_COUNT}; +pub type ProcessId = u32; pub type ProcessManagerRef = Arc; /// Query process manager. @@ -36,9 +37,9 @@ pub struct ProcessManager { /// Local frontend server address, server_addr: String, /// Next process id for local queries. - next_id: AtomicU64, + next_id: AtomicU32, /// Running process per catalog. - catalogs: RwLock>>, + catalogs: RwLock>>, /// Frontend selector to locate frontend nodes. frontend_selector: Option, } @@ -65,9 +66,9 @@ impl ProcessManager { schemas: Vec, query: String, client: String, - id: Option, + query_id: Option, ) -> Ticket { - let id = id.unwrap_or_else(|| self.next_id.fetch_add(1, Ordering::Relaxed)); + let id = query_id.unwrap_or_else(|| self.next_id.fetch_add(1, Ordering::Relaxed)); let process = ProcessInfo { id, catalog: catalog.clone(), @@ -96,12 +97,12 @@ impl ProcessManager { } /// Generates the next process id. - pub fn next_id(&self) -> u64 { + pub fn next_id(&self) -> u32 { self.next_id.fetch_add(1, Ordering::Relaxed) } /// De-register a query from process list. - pub fn deregister_query(&self, catalog: String, id: u64) { + pub fn deregister_query(&self, catalog: String, id: ProcessId) { if let Entry::Occupied(mut o) = self.catalogs.write().unwrap().entry(catalog) { let process = o.get_mut().remove(&id); debug!("Deregister process: {:?}", process); @@ -159,26 +160,10 @@ impl ProcessManager { &self, server_addr: String, catalog: String, - id: u64, + id: ProcessId, ) -> error::Result { if server_addr == self.server_addr { - if let Some(catalogs) = self.catalogs.write().unwrap().get_mut(&catalog) { - if let Some(process) = catalogs.remove(&id) { - process.handle.cancel(); - info!( - "Killed process, catalog: {}, id: {:?}", - process.process.catalog, process.process.id - ); - PROCESS_KILL_COUNT.with_label_values(&[&catalog]).inc(); - Ok(true) - } else { - debug!("Failed to kill process, id not found: {}", id); - Ok(false) - } - } else { - debug!("Failed to kill process, catalog not found: {}", catalog); - Ok(false) - } + self.kill_local_process(catalog, id).await } else { let mut nodes = self .frontend_selector @@ -204,12 +189,33 @@ impl ProcessManager { Ok(true) } } + + /// Kills local query with provided catalog and id. + pub async fn kill_local_process(&self, catalog: String, id: ProcessId) -> error::Result { + if let Some(catalogs) = self.catalogs.write().unwrap().get_mut(&catalog) { + if let Some(process) = catalogs.remove(&id) { + process.handle.cancel(); + info!( + "Killed process, catalog: {}, id: {:?}", + process.process.catalog, process.process.id + ); + PROCESS_KILL_COUNT.with_label_values(&[&catalog]).inc(); + Ok(true) + } else { + debug!("Failed to kill process, id not found: {}", id); + Ok(false) + } + } else { + debug!("Failed to kill process, catalog not found: {}", catalog); + Ok(false) + } + } } pub struct Ticket { pub(crate) catalog: String, pub(crate) manager: ProcessManagerRef, - pub(crate) id: u64, + pub(crate) id: ProcessId, pub cancellation_handle: Arc, } @@ -323,7 +329,7 @@ mod tests { assert_eq!(running_processes.len(), 2); // Verify both processes are present - let ids: Vec = running_processes.iter().map(|p| p.id).collect(); + let ids: Vec = running_processes.iter().map(|p| p.id).collect(); assert!(ids.contains(&ticket1.id)); assert!(ids.contains(&ticket2.id)); } diff --git a/src/common/frontend/src/lib.rs b/src/common/frontend/src/lib.rs index ab07f54168..cfd485e25b 100644 --- a/src/common/frontend/src/lib.rs +++ b/src/common/frontend/src/lib.rs @@ -23,7 +23,7 @@ pub mod selector; #[derive(Debug, Clone, Eq, PartialEq)] pub struct DisplayProcessId { pub server_addr: String, - pub id: u64, + pub id: u32, } impl Display for DisplayProcessId { @@ -44,7 +44,7 @@ impl TryFrom<&str> for DisplayProcessId { let id = split .next() .context(error::ParseProcessIdSnafu { s: value })?; - let id = u64::from_str(id) + let id = u32::from_str(id) .ok() .context(error::ParseProcessIdSnafu { s: value })?; Ok(DisplayProcessId { server_addr, id }) diff --git a/src/common/function/src/system/database.rs b/src/common/function/src/system/database.rs index ced90b2bcb..e1cbd7162b 100644 --- a/src/common/function/src/system/database.rs +++ b/src/common/function/src/system/database.rs @@ -18,7 +18,7 @@ use std::sync::Arc; use common_query::error::Result; use common_query::prelude::{Signature, Volatility}; use datatypes::prelude::{ConcreteDataType, ScalarVector}; -use datatypes::vectors::{StringVector, UInt64Vector, VectorRef}; +use datatypes::vectors::{StringVector, UInt32Vector, VectorRef}; use derive_more::Display; use crate::function::{Function, FunctionContext}; @@ -144,7 +144,7 @@ impl Function for PgBackendPidFunction { fn eval(&self, func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result { let pid = func_ctx.query_ctx.process_id(); - Ok(Arc::new(UInt64Vector::from_slice([pid])) as _) + Ok(Arc::new(UInt32Vector::from_slice([pid])) as _) } } @@ -164,7 +164,7 @@ impl Function for ConnectionIdFunction { fn eval(&self, func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result { let pid = func_ctx.query_ctx.process_id(); - Ok(Arc::new(UInt64Vector::from_slice([pid])) as _) + Ok(Arc::new(UInt32Vector::from_slice([pid])) as _) } } diff --git a/src/frontend/src/server.rs b/src/frontend/src/server.rs index 0aa8d360d6..bc64219e11 100644 --- a/src/frontend/src/server.rs +++ b/src/frontend/src/server.rs @@ -235,6 +235,7 @@ where opts.keep_alive.as_secs(), opts.reject_no_database.unwrap_or(false), )), + Some(instance.process_manager().clone()), ); handlers.insert((mysql_server, mysql_addr)); } @@ -257,6 +258,7 @@ where opts.keep_alive.as_secs(), common_runtime::global_runtime(), user_provider.clone(), + Some(self.instance.process_manager().clone()), )) as Box; handlers.insert((pg_server, pg_addr)); diff --git a/src/operator/src/statement.rs b/src/operator/src/statement.rs index 7968280be0..acc9265347 100644 --- a/src/operator/src/statement.rs +++ b/src/operator/src/statement.rs @@ -369,7 +369,7 @@ impl StatementExecutor { Statement::ShowSearchPath(_) => self.show_search_path(query_ctx).await, Statement::Use(db) => self.use_database(db, query_ctx).await, Statement::Admin(admin) => self.execute_admin_command(admin, query_ctx).await, - Statement::Kill(id) => self.execute_kill(query_ctx, id).await, + Statement::Kill(kill) => self.execute_kill(query_ctx, kill).await, } } diff --git a/src/operator/src/statement/kill.rs b/src/operator/src/statement/kill.rs index 7da90c13cb..69d897c4c2 100644 --- a/src/operator/src/statement/kill.rs +++ b/src/operator/src/statement/kill.rs @@ -12,11 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. +use catalog::process_manager::ProcessManagerRef; use common_frontend::DisplayProcessId; use common_query::Output; use common_telemetry::error; use session::context::QueryContextRef; use snafu::ResultExt; +use sql::statements::kill::Kill; use crate::error; use crate::statement::StatementExecutor; @@ -25,22 +27,51 @@ impl StatementExecutor { pub async fn execute_kill( &self, query_ctx: QueryContextRef, - process_id: String, - ) -> crate::error::Result { + kill: Kill, + ) -> error::Result { let Some(process_manager) = self.process_manager.as_ref() else { error!("Process manager is not initialized"); return error::ProcessManagerMissingSnafu.fail(); }; + let succ = match kill { + Kill::ProcessId(process_id) => { + self.kill_process_id(process_manager, query_ctx, process_id) + .await? + } + Kill::ConnectionId(conn_id) => { + self.kill_connection_id(process_manager, query_ctx, conn_id) + .await? + } + }; + Ok(Output::new_with_affected_rows(if succ { 1 } else { 0 })) + } + + /// Handles `KILL ` statements. + async fn kill_process_id( + &self, + pm: &ProcessManagerRef, + query_ctx: QueryContextRef, + process_id: String, + ) -> error::Result { let display_id = DisplayProcessId::try_from(process_id.as_str()) .map_err(|_| error::InvalidProcessIdSnafu { id: process_id }.build())?; let current_user_catalog = query_ctx.current_catalog().to_string(); - process_manager - .kill_process(display_id.server_addr, current_user_catalog, display_id.id) + pm.kill_process(display_id.server_addr, current_user_catalog, display_id.id) .await - .context(error::CatalogSnafu)?; + .context(error::CatalogSnafu) + } - Ok(Output::new_with_affected_rows(0)) + /// Handles MySQL `KILL QUERY ` statements. + pub async fn kill_connection_id( + &self, + pm: &ProcessManagerRef, + query_ctx: QueryContextRef, + connection_id: u32, + ) -> error::Result { + pm.kill_local_process(query_ctx.current_catalog().to_string(), connection_id) + .await + .context(error::CatalogSnafu) } } diff --git a/src/servers/src/mysql/handler.rs b/src/servers/src/mysql/handler.rs index f82112c8a9..937deb0823 100644 --- a/src/servers/src/mysql/handler.rs +++ b/src/servers/src/mysql/handler.rs @@ -82,6 +82,7 @@ pub struct MysqlInstanceShim { user_provider: Option, prepared_stmts: Arc>>, prepared_stmts_counter: AtomicU32, + process_id: u32, } impl MysqlInstanceShim { @@ -89,6 +90,7 @@ impl MysqlInstanceShim { query_handler: ServerSqlQueryHandlerRef, user_provider: Option, client_addr: SocketAddr, + process_id: u32, ) -> MysqlInstanceShim { // init a random salt let mut bs = vec![0u8; 20]; @@ -110,12 +112,12 @@ impl MysqlInstanceShim { Some(client_addr), Channel::Mysql, Default::default(), - // TODO(sunng87): generate process id properly - 0, + process_id, )), user_provider, prepared_stmts: Default::default(), prepared_stmts_counter: AtomicU32::new(1), + process_id, } } @@ -341,6 +343,10 @@ impl AsyncMysqlShim for MysqlInstanceShi std::env::var("GREPTIMEDB_MYSQL_SERVER_VERSION").unwrap_or_else(|_| "8.4.2".to_string()) } + fn connect_id(&self) -> u32 { + self.process_id + } + fn default_auth_plugin(&self) -> &str { self.auth_plugin() } diff --git a/src/servers/src/mysql/server.rs b/src/servers/src/mysql/server.rs index f274ccce32..57aef8796f 100644 --- a/src/servers/src/mysql/server.rs +++ b/src/servers/src/mysql/server.rs @@ -18,6 +18,7 @@ use std::sync::Arc; use async_trait::async_trait; use auth::UserProviderRef; +use catalog::process_manager::ProcessManagerRef; use common_runtime::runtime::RuntimeTrait; use common_runtime::Runtime; use common_telemetry::{debug, warn}; @@ -112,6 +113,7 @@ pub struct MysqlServer { spawn_ref: Arc, spawn_config: Arc, bind_addr: Option, + process_manager: Option, } impl MysqlServer { @@ -119,16 +121,23 @@ impl MysqlServer { io_runtime: Runtime, spawn_ref: Arc, spawn_config: Arc, + process_manager: Option, ) -> Box { Box::new(MysqlServer { base_server: BaseTcpServer::create_server("MySQL", io_runtime), spawn_ref, spawn_config, bind_addr: None, + process_manager, }) } - fn accept(&self, io_runtime: Runtime, stream: AbortableStream) -> impl Future { + fn accept( + &self, + io_runtime: Runtime, + stream: AbortableStream, + process_manager: Option, + ) -> impl Future { let spawn_ref = self.spawn_ref.clone(); let spawn_config = self.spawn_config.clone(); @@ -136,7 +145,7 @@ impl MysqlServer { let spawn_ref = spawn_ref.clone(); let spawn_config = spawn_config.clone(); let io_runtime = io_runtime.clone(); - + let process_id = process_manager.as_ref().map(|p| p.next_id()).unwrap_or(8); async move { match tcp_stream { Err(e) => warn!(e; "Broken pipe"), // IoError doesn't impl ErrorExt. @@ -146,7 +155,7 @@ impl MysqlServer { } io_runtime.spawn(async move { if let Err(error) = - Self::handle(io_stream, spawn_ref, spawn_config).await + Self::handle(io_stream, spawn_ref, spawn_config, process_id).await { warn!(error; "Unexpected error when handling TcpStream"); }; @@ -161,10 +170,11 @@ impl MysqlServer { stream: TcpStream, spawn_ref: Arc, spawn_config: Arc, + process_id: u32, ) -> Result<()> { debug!("MySQL connection coming from: {}", stream.peer_addr()?); crate::metrics::METRIC_MYSQL_CONNECTIONS.inc(); - if let Err(e) = Self::do_handle(stream, spawn_ref, spawn_config).await { + if let Err(e) = Self::do_handle(stream, spawn_ref, spawn_config, process_id).await { if let Error::InternalIo { error } = &e && error.kind() == std::io::ErrorKind::ConnectionAborted { @@ -184,11 +194,13 @@ impl MysqlServer { stream: TcpStream, spawn_ref: Arc, spawn_config: Arc, + process_id: u32, ) -> Result<()> { let mut shim = MysqlInstanceShim::create( spawn_ref.query_handler(), spawn_ref.user_provider(), stream.peer_addr()?, + process_id, ); let (mut r, w) = stream.into_split(); let mut w = BufWriter::with_capacity(DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE, w); @@ -230,7 +242,11 @@ impl Server for MysqlServer { .await?; let io_runtime = self.base_server.io_runtime(); - let join_handle = common_runtime::spawn_global(self.accept(io_runtime, stream)); + let join_handle = common_runtime::spawn_global(self.accept( + io_runtime, + stream, + self.process_manager.clone(), + )); self.base_server.start_with(join_handle).await?; self.bind_addr = Some(addr); diff --git a/src/servers/src/postgres.rs b/src/servers/src/postgres.rs index 0c7dd8269a..9ae3234785 100644 --- a/src/servers/src/postgres.rs +++ b/src/servers/src/postgres.rs @@ -119,9 +119,13 @@ impl PgWireServerHandlers for PostgresServerHandler { } impl MakePostgresServerHandler { - fn make(&self, addr: Option) -> PostgresServerHandler { - // TODO(sunng87): generate pid from process manager - let session = Arc::new(Session::new(addr, Channel::Postgres, Default::default(), 0)); + fn make(&self, addr: Option, process_id: u32) -> PostgresServerHandler { + let session = Arc::new(Session::new( + addr, + Channel::Postgres, + Default::default(), + process_id, + )); let handler = PostgresServerHandlerInner { query_handler: self.query_handler.clone(), login_verifier: PgLoginVerifier::new(self.user_provider.clone()), diff --git a/src/servers/src/postgres/server.rs b/src/servers/src/postgres/server.rs index cc271e7709..a509771fcf 100644 --- a/src/servers/src/postgres/server.rs +++ b/src/servers/src/postgres/server.rs @@ -18,6 +18,7 @@ use std::sync::Arc; use ::auth::UserProviderRef; use async_trait::async_trait; +use catalog::process_manager::ProcessManagerRef; use common_runtime::runtime::RuntimeTrait; use common_runtime::Runtime; use common_telemetry::{debug, warn}; @@ -37,6 +38,7 @@ pub struct PostgresServer { tls_server_config: Arc, keep_alive_secs: u64, bind_addr: Option, + process_manager: Option, } impl PostgresServer { @@ -48,6 +50,7 @@ impl PostgresServer { keep_alive_secs: u64, io_runtime: Runtime, user_provider: Option, + process_manager: Option, ) -> PostgresServer { let make_handler = Arc::new( MakePostgresServerHandlerBuilder::default() @@ -63,6 +66,7 @@ impl PostgresServer { tls_server_config, keep_alive_secs, bind_addr: None, + process_manager, } } @@ -73,12 +77,12 @@ impl PostgresServer { ) -> impl Future { let handler_maker = self.make_handler.clone(); let tls_server_config = self.tls_server_config.clone(); + let process_manager = self.process_manager.clone(); accepting_stream.for_each(move |tcp_stream| { let io_runtime = io_runtime.clone(); - let tls_acceptor = tls_server_config.get_server_config().map(TlsAcceptor::from); - let handler_maker = handler_maker.clone(); + let process_id = process_manager.as_ref().map(|p| p.next_id()).unwrap_or(0); async move { match tcp_stream { @@ -97,7 +101,7 @@ impl PostgresServer { let _handle = io_runtime.spawn(async move { crate::metrics::METRIC_POSTGRES_CONNECTIONS.inc(); - let pg_handler = Arc::new(handler_maker.make(addr)); + let pg_handler = Arc::new(handler_maker.make(addr, process_id)); let r = process_socket(io_stream, tls_acceptor.clone(), pg_handler).await; crate::metrics::METRIC_POSTGRES_CONNECTIONS.dec(); diff --git a/src/servers/tests/mysql/mysql_server_test.rs b/src/servers/tests/mysql/mysql_server_test.rs index 56ddb574cc..a3d299a90a 100644 --- a/src/servers/tests/mysql/mysql_server_test.rs +++ b/src/servers/tests/mysql/mysql_server_test.rs @@ -73,6 +73,7 @@ fn create_mysql_server(table: TableRef, opts: MysqlOpts<'_>) -> Result u64 { + pub fn process_id(&self) -> u32 { self.process_id } diff --git a/src/session/src/lib.rs b/src/session/src/lib.rs index 5d4ce4db1a..7688b0b659 100644 --- a/src/session/src/lib.rs +++ b/src/session/src/lib.rs @@ -41,8 +41,8 @@ pub struct Session { mutable_inner: Arc>, conn_info: ConnInfo, configuration_variables: Arc, - // the connection id to use when killing the query - process_id: u64, + // the process id to use when killing the query + process_id: u32, } pub type SessionRef = Arc; @@ -77,7 +77,7 @@ impl Session { addr: Option, channel: Channel, configuration_variables: ConfigurationVariables, - process_id: u64, + process_id: u32, ) -> Self { Session { catalog: RwLock::new(DEFAULT_CATALOG_NAME.into()), @@ -152,7 +152,7 @@ impl Session { build_db_string(&self.catalog(), &self.schema()) } - pub fn process_id(&self) -> u64 { + pub fn process_id(&self) -> u32 { self.process_id } } diff --git a/src/sql/src/parser.rs b/src/sql/src/parser.rs index 63c6465cc7..0da97cbcaf 100644 --- a/src/sql/src/parser.rs +++ b/src/sql/src/parser.rs @@ -12,8 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::str::FromStr; + use snafu::ResultExt; -use sqlparser::ast::{Ident, Query}; +use sqlparser::ast::{Ident, Query, Value}; use sqlparser::dialect::Dialect; use sqlparser::keywords::Keyword; use sqlparser::parser::{Parser, ParserError, ParserOptions}; @@ -22,6 +24,7 @@ use sqlparser::tokenizer::{Token, TokenWithSpan}; use crate::ast::{Expr, ObjectName}; use crate::error::{self, Result, SyntaxSnafu}; use crate::parsers::tql_parser; +use crate::statements::kill::Kill; use crate::statements::statement::Statement; use crate::statements::transform_statements; @@ -190,14 +193,43 @@ impl ParserContext<'_> { Keyword::KILL => { let _ = self.parser.next_token(); - let process_id_ident = - self.parser.parse_literal_string().with_context(|_| { - error::UnexpectedSnafu { - expected: "process id string literal", - actual: self.peek_token_as_string(), + let kill = if self.parser.parse_keyword(Keyword::QUERY) { + // MySQL KILL QUERY statements + let connection_id_exp = + self.parser.parse_number_value().with_context(|_| { + error::UnexpectedSnafu { + expected: "MySQL numeric connection id", + actual: self.peek_token_as_string(), + } + })?; + let Value::Number(s, _) = connection_id_exp else { + return error::UnexpectedTokenSnafu { + expected: "MySQL numeric connection id", + actual: connection_id_exp.to_string(), } + .fail(); + }; + + let connection_id = u32::from_str(&s).map_err(|_| { + error::UnexpectedTokenSnafu { + expected: "MySQL numeric connection id", + actual: s, + } + .build() })?; - Ok(Statement::Kill(process_id_ident)) + Kill::ConnectionId(connection_id) + } else { + let process_id_ident = + self.parser.parse_literal_string().with_context(|_| { + error::UnexpectedSnafu { + expected: "process id string literal", + actual: self.peek_token_as_string(), + } + })?; + Kill::ProcessId(process_id_ident) + }; + + Ok(Statement::Kill(kill)) } _ => self.unsupported(self.peek_token_as_string()), @@ -440,4 +472,191 @@ mod tests { let stmt_name = ParserContext::parse_mysql_deallocate_stmt(sql, &MySqlDialect {}).unwrap(); assert_eq!(stmt_name, "stmt2"); } + + #[test] + pub fn test_parse_kill_query_statement() { + use crate::statements::kill::Kill; + + // Test MySQL-style KILL QUERY with connection ID + let sql = "KILL QUERY 123"; + let statements = + ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default()) + .unwrap(); + + assert_eq!(statements.len(), 1); + match &statements[0] { + Statement::Kill(Kill::ConnectionId(connection_id)) => { + assert_eq!(*connection_id, 123); + } + _ => panic!("Expected Kill::ConnectionId statement"), + } + + // Test with larger connection ID + let sql = "KILL QUERY 999999"; + let statements = + ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default()) + .unwrap(); + + assert_eq!(statements.len(), 1); + match &statements[0] { + Statement::Kill(Kill::ConnectionId(connection_id)) => { + assert_eq!(*connection_id, 999999); + } + _ => panic!("Expected Kill::ConnectionId statement"), + } + } + + #[test] + pub fn test_parse_kill_process_statement() { + use crate::statements::kill::Kill; + + // Test KILL with process ID string + let sql = "KILL 'process-123'"; + let statements = + ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default()) + .unwrap(); + + assert_eq!(statements.len(), 1); + match &statements[0] { + Statement::Kill(Kill::ProcessId(process_id)) => { + assert_eq!(process_id, "process-123"); + } + _ => panic!("Expected Kill::ProcessId statement"), + } + + // Test with double quotes + let sql = "KILL \"process-456\""; + let statements = + ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default()) + .unwrap(); + + assert_eq!(statements.len(), 1); + match &statements[0] { + Statement::Kill(Kill::ProcessId(process_id)) => { + assert_eq!(process_id, "process-456"); + } + _ => panic!("Expected Kill::ProcessId statement"), + } + + // Test with UUID-like process ID + let sql = "KILL 'f47ac10b-58cc-4372-a567-0e02b2c3d479'"; + let statements = + ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default()) + .unwrap(); + + assert_eq!(statements.len(), 1); + match &statements[0] { + Statement::Kill(Kill::ProcessId(process_id)) => { + assert_eq!(process_id, "f47ac10b-58cc-4372-a567-0e02b2c3d479"); + } + _ => panic!("Expected Kill::ProcessId statement"), + } + } + + #[test] + pub fn test_parse_kill_statement_errors() { + // Test KILL QUERY without connection ID + let sql = "KILL QUERY"; + let result = + ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default()); + assert!(result.is_err()); + + // Test KILL QUERY with non-numeric connection ID + let sql = "KILL QUERY 'not-a-number'"; + let result = + ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default()); + assert!(result.is_err()); + + // Test KILL without any argument + let sql = "KILL"; + let result = + ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default()); + assert!(result.is_err()); + + // Test KILL QUERY with connection ID that's too large for u32 + let sql = "KILL QUERY 4294967296"; // u32::MAX + 1 + let result = + ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default()); + assert!(result.is_err()); + } + + #[test] + pub fn test_parse_kill_statement_edge_cases() { + use crate::statements::kill::Kill; + + // Test KILL QUERY with zero connection ID + let sql = "KILL QUERY 0"; + let statements = + ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default()) + .unwrap(); + + assert_eq!(statements.len(), 1); + match &statements[0] { + Statement::Kill(Kill::ConnectionId(connection_id)) => { + assert_eq!(*connection_id, 0); + } + _ => panic!("Expected Kill::ConnectionId statement"), + } + + // Test KILL QUERY with maximum u32 value + let sql = "KILL QUERY 4294967295"; // u32::MAX + let statements = + ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default()) + .unwrap(); + + assert_eq!(statements.len(), 1); + match &statements[0] { + Statement::Kill(Kill::ConnectionId(connection_id)) => { + assert_eq!(*connection_id, 4294967295); + } + _ => panic!("Expected Kill::ConnectionId statement"), + } + + // Test KILL with empty string process ID + let sql = "KILL ''"; + let statements = + ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default()) + .unwrap(); + + assert_eq!(statements.len(), 1); + match &statements[0] { + Statement::Kill(Kill::ProcessId(process_id)) => { + assert_eq!(process_id, ""); + } + _ => panic!("Expected Kill::ProcessId statement"), + } + } + + #[test] + pub fn test_parse_kill_statement_case_insensitive() { + use crate::statements::kill::Kill; + + // Test lowercase + let sql = "kill query 123"; + let statements = + ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default()) + .unwrap(); + + assert_eq!(statements.len(), 1); + match &statements[0] { + Statement::Kill(Kill::ConnectionId(connection_id)) => { + assert_eq!(*connection_id, 123); + } + _ => panic!("Expected Kill::ConnectionId statement"), + } + + // Test mixed case + let sql = "Kill Query 456"; + let statements = + ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default()) + .unwrap(); + + assert_eq!(statements.len(), 1); + match &statements[0] { + Statement::Kill(Kill::ConnectionId(connection_id)) => { + assert_eq!(*connection_id, 456); + } + _ => panic!("Expected Kill::ConnectionId statement"), + } + } } diff --git a/src/sql/src/statements.rs b/src/sql/src/statements.rs index f68c9434e9..8bed705b14 100644 --- a/src/sql/src/statements.rs +++ b/src/sql/src/statements.rs @@ -22,6 +22,7 @@ pub mod describe; pub mod drop; pub mod explain; pub mod insert; +pub mod kill; mod option_map; pub mod query; pub mod set_variables; diff --git a/src/sql/src/statements/kill.rs b/src/sql/src/statements/kill.rs new file mode 100644 index 0000000000..daf0e0993b --- /dev/null +++ b/src/sql/src/statements/kill.rs @@ -0,0 +1,40 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::fmt::{Display, Formatter}; + +use serde::Serialize; +use sqlparser_derive::{Visit, VisitMut}; + +/// Arguments of `KILL` statements. +#[derive(Debug, Clone, Eq, PartialEq, Visit, VisitMut, Serialize)] +pub enum Kill { + /// Kill a remote process id. + ProcessId(String), + /// Kill MySQL connection id. + ConnectionId(u32), +} + +impl Display for Kill { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Kill::ProcessId(id) => { + write!(f, "KILL {}", id) + } + Kill::ConnectionId(id) => { + write!(f, "KILL QUERY {}", id) + } + } + } +} diff --git a/src/sql/src/statements/statement.rs b/src/sql/src/statements/statement.rs index 56492e75bc..030a2334c1 100644 --- a/src/sql/src/statements/statement.rs +++ b/src/sql/src/statements/statement.rs @@ -32,6 +32,7 @@ use crate::statements::describe::DescribeTable; use crate::statements::drop::{DropDatabase, DropFlow, DropTable, DropView}; use crate::statements::explain::Explain; use crate::statements::insert::Insert; +use crate::statements::kill::Kill; use crate::statements::query::Query; use crate::statements::set_variables::SetVariables; use crate::statements::show::{ @@ -139,7 +140,7 @@ pub enum Statement { // CLOSE CloseCursor(CloseCursor), // KILL - Kill(String), + Kill(Kill), } impl Display for Statement { diff --git a/tests-integration/src/test_util.rs b/tests-integration/src/test_util.rs index 9bec5eb624..8cc0932240 100644 --- a/tests-integration/src/test_util.rs +++ b/tests-integration/src/test_util.rs @@ -648,6 +648,7 @@ pub async fn setup_mysql_server_with_user_provider( 0, opts.reject_no_database.unwrap_or(false), )), + None, ); mysql_server @@ -697,6 +698,7 @@ pub async fn setup_pg_server_with_user_provider( 0, runtime, user_provider, + None, )); pg_server