mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-01-08 22:32:55 +00:00
feat: handle Ctrl-C command in MySQL client (#6320)
* feat/answer-ctrl-c-in-mysql: ## Implement Connection ID-based Query Killing ### Key Changes: - **Connection ID Management:** - Added `connection_id` to `Session` and `QueryContext` in `src/session/src/lib.rs` and `src/session/src/context.rs`. - Updated `MysqlInstanceShim` and `MysqlServer` to handle `connection_id` in `src/servers/src/mysql/handler.rs` and `src/servers/src/mysql/server.rs`. - **KILL Statement Enhancements:** - Introduced `Kill` enum to handle both `ProcessId` and `ConnectionId` in `src/sql/src/statements/kill.rs`. - Updated `ParserContext` to parse `KILL QUERY <connection_id>` in `src/sql/src/parser.rs`. - Modified `StatementExecutor` to support killing queries by `connection_id` in `src/operator/src/statement/kill.rs`. - **Process Management:** - Refactored `ProcessManager` to include `connection_id` in `src/catalog/src/process_manager.rs`. - Added `kill_local_process` method for local query termination. - **Testing:** - Added tests for `KILL` statement parsing and execution in `src/sql/src/parser.rs`. ### Affected Files: - `Cargo.lock`, `Cargo.toml` - `src/catalog/src/process_manager.rs` - `src/frontend/src/instance.rs` - `src/frontend/src/stream_wrapper.rs` - `src/operator/src/statement.rs` - `src/operator/src/statement/kill.rs` - `src/servers/src/mysql/federated.rs` - `src/servers/src/mysql/handler.rs` - `src/servers/src/mysql/server.rs` - `src/servers/src/postgres.rs` - `src/session/src/context.rs` - `src/session/src/lib.rs` - `src/sql/src/parser.rs` - `src/sql/src/statements.rs` - `src/sql/src/statements/kill.rs` - `src/sql/src/statements/statement.rs` Signed-off-by: Lei, HUANG <mrsatangel@gmail.com> Conflicts: Cargo.lock Cargo.toml Signed-off-by: Lei, HUANG <mrsatangel@gmail.com> * feat/answer-ctrl-c-in-mysql: ### Enhance Process Management and Execution - **`process_manager.rs`**: Added a new method `find_processes_by_connection_id` to filter processes by connection ID, improving process management capabilities. - **`kill.rs`**: Refactored the process killing logic to utilize the new `find_processes_by_connection_id` method, streamlining the execution flow and reducing redundant checks. Signed-off-by: Lei, HUANG <mrsatangel@gmail.com> * feat/answer-ctrl-c-in-mysql: ## Commit Message ### Update Process ID Type and Refactor Code - **Change Process ID Type**: Updated the process ID type from `u64` to `u32` across multiple files to optimize memory usage. Affected files include `process_manager.rs`, `lib.rs`, `database.rs`, `instance.rs`, `server.rs`, `stream_wrapper.rs`, `kill.rs`, `federated.rs`, `handler.rs`, `server.rs`, `postgres.rs`, `mysql_server_test.rs`, `context.rs`, `lib.rs`, and `test_util.rs`. - **Remove Connection ID**: Removed the `connection_id` field and related logic from `process_manager.rs`, `lib.rs`, `instance.rs`, `server.rs`, `stream_wrapper.rs`, `kill.rs`, `federated.rs`, `handler.rs`, `server.rs`, `postgres.rs`, `mysql_server_test.rs`, `context.rs`, `lib.rs`, and `test_util.rs` to simplify the codebase. - **Refactor Process Management**: Refactored process management logic to improve clarity and maintainability in `process_manager.rs`, `kill.rs`, and `handler.rs`. - **Enhance MySQL Server Handling**: Improved MySQL server handling by integrating process management in `server.rs` and `mysql_server_test.rs`. Signed-off-by: Lei, HUANG <mrsatangel@gmail.com> * feat/answer-ctrl-c-in-mysql: ### Add Process Manager to Postgres Server - **`src/frontend/src/server.rs`**: Updated server initialization to include `process_manager`. - **`src/servers/src/postgres.rs`**: Modified `MakePostgresServerHandler` to accept `process_id` for session creation. - **`src/servers/src/postgres/server.rs`**: Integrated `process_manager` into `PostgresServer` for generating `process_id` during connection handling. - **`src/servers/tests/postgres/mod.rs`** and **`tests-integration/src/test_util.rs`**: Adjusted test server setup to accommodate optional `process_manager`. Signed-off-by: Lei, HUANG <mrsatangel@gmail.com> * feat/answer-ctrl-c-in-mysql: Update `greptime-proto` Dependency - Updated the `greptime-proto` dependency to a new revision in both `Cargo.lock` and `Cargo.toml`. - `Cargo.lock`: Changed source revision from `d75a56e05a87594fe31ad5c48525e9b2124149ba` to `fdcbe5f1c7c467634c90a1fd1a00a784b92a4e80`. - `Cargo.toml`: Updated the `greptime-proto` git revision to match the new commit. Signed-off-by: Lei, HUANG <mrsatangel@gmail.com> --------- Signed-off-by: Lei, HUANG <mrsatangel@gmail.com>
This commit is contained in:
2
Cargo.lock
generated
2
Cargo.lock
generated
@@ -5143,7 +5143,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "greptime-proto"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/GreptimeTeam/greptime-proto.git?rev=17971523673f4fbc982510d3c9d6647ff642e16f#17971523673f4fbc982510d3c9d6647ff642e16f"
|
||||
source = "git+https://github.com/GreptimeTeam/greptime-proto.git?rev=fdcbe5f1c7c467634c90a1fd1a00a784b92a4e80#fdcbe5f1c7c467634c90a1fd1a00a784b92a4e80"
|
||||
dependencies = [
|
||||
"prost 0.13.5",
|
||||
"serde",
|
||||
|
||||
@@ -134,7 +134,7 @@ etcd-client = "0.14"
|
||||
fst = "0.4.7"
|
||||
futures = "0.3"
|
||||
futures-util = "0.3"
|
||||
greptime-proto = { git = "https://github.com/GreptimeTeam/greptime-proto.git", rev = "17971523673f4fbc982510d3c9d6647ff642e16f" }
|
||||
greptime-proto = { git = "https://github.com/GreptimeTeam/greptime-proto.git", rev = "fdcbe5f1c7c467634c90a1fd1a00a784b92a4e80" }
|
||||
hex = "0.4"
|
||||
http = "1"
|
||||
humantime = "2.1"
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
use std::collections::hash_map::Entry;
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::{Debug, Formatter};
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::atomic::{AtomicU32, Ordering};
|
||||
use std::sync::{Arc, RwLock};
|
||||
|
||||
use api::v1::frontend::{KillProcessRequest, ListProcessRequest, ProcessInfo};
|
||||
@@ -29,6 +29,7 @@ use snafu::{ensure, OptionExt, ResultExt};
|
||||
use crate::error;
|
||||
use crate::metrics::{PROCESS_KILL_COUNT, PROCESS_LIST_COUNT};
|
||||
|
||||
pub type ProcessId = u32;
|
||||
pub type ProcessManagerRef = Arc<ProcessManager>;
|
||||
|
||||
/// Query process manager.
|
||||
@@ -36,9 +37,9 @@ pub struct ProcessManager {
|
||||
/// Local frontend server address,
|
||||
server_addr: String,
|
||||
/// Next process id for local queries.
|
||||
next_id: AtomicU64,
|
||||
next_id: AtomicU32,
|
||||
/// Running process per catalog.
|
||||
catalogs: RwLock<HashMap<String, HashMap<u64, CancellableProcess>>>,
|
||||
catalogs: RwLock<HashMap<String, HashMap<ProcessId, CancellableProcess>>>,
|
||||
/// Frontend selector to locate frontend nodes.
|
||||
frontend_selector: Option<MetaClientSelector>,
|
||||
}
|
||||
@@ -65,9 +66,9 @@ impl ProcessManager {
|
||||
schemas: Vec<String>,
|
||||
query: String,
|
||||
client: String,
|
||||
id: Option<u64>,
|
||||
query_id: Option<ProcessId>,
|
||||
) -> Ticket {
|
||||
let id = id.unwrap_or_else(|| self.next_id.fetch_add(1, Ordering::Relaxed));
|
||||
let id = query_id.unwrap_or_else(|| self.next_id.fetch_add(1, Ordering::Relaxed));
|
||||
let process = ProcessInfo {
|
||||
id,
|
||||
catalog: catalog.clone(),
|
||||
@@ -96,12 +97,12 @@ impl ProcessManager {
|
||||
}
|
||||
|
||||
/// Generates the next process id.
|
||||
pub fn next_id(&self) -> u64 {
|
||||
pub fn next_id(&self) -> u32 {
|
||||
self.next_id.fetch_add(1, Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// De-register a query from process list.
|
||||
pub fn deregister_query(&self, catalog: String, id: u64) {
|
||||
pub fn deregister_query(&self, catalog: String, id: ProcessId) {
|
||||
if let Entry::Occupied(mut o) = self.catalogs.write().unwrap().entry(catalog) {
|
||||
let process = o.get_mut().remove(&id);
|
||||
debug!("Deregister process: {:?}", process);
|
||||
@@ -159,26 +160,10 @@ impl ProcessManager {
|
||||
&self,
|
||||
server_addr: String,
|
||||
catalog: String,
|
||||
id: u64,
|
||||
id: ProcessId,
|
||||
) -> error::Result<bool> {
|
||||
if server_addr == self.server_addr {
|
||||
if let Some(catalogs) = self.catalogs.write().unwrap().get_mut(&catalog) {
|
||||
if let Some(process) = catalogs.remove(&id) {
|
||||
process.handle.cancel();
|
||||
info!(
|
||||
"Killed process, catalog: {}, id: {:?}",
|
||||
process.process.catalog, process.process.id
|
||||
);
|
||||
PROCESS_KILL_COUNT.with_label_values(&[&catalog]).inc();
|
||||
Ok(true)
|
||||
} else {
|
||||
debug!("Failed to kill process, id not found: {}", id);
|
||||
Ok(false)
|
||||
}
|
||||
} else {
|
||||
debug!("Failed to kill process, catalog not found: {}", catalog);
|
||||
Ok(false)
|
||||
}
|
||||
self.kill_local_process(catalog, id).await
|
||||
} else {
|
||||
let mut nodes = self
|
||||
.frontend_selector
|
||||
@@ -204,12 +189,33 @@ impl ProcessManager {
|
||||
Ok(true)
|
||||
}
|
||||
}
|
||||
|
||||
/// Kills local query with provided catalog and id.
|
||||
pub async fn kill_local_process(&self, catalog: String, id: ProcessId) -> error::Result<bool> {
|
||||
if let Some(catalogs) = self.catalogs.write().unwrap().get_mut(&catalog) {
|
||||
if let Some(process) = catalogs.remove(&id) {
|
||||
process.handle.cancel();
|
||||
info!(
|
||||
"Killed process, catalog: {}, id: {:?}",
|
||||
process.process.catalog, process.process.id
|
||||
);
|
||||
PROCESS_KILL_COUNT.with_label_values(&[&catalog]).inc();
|
||||
Ok(true)
|
||||
} else {
|
||||
debug!("Failed to kill process, id not found: {}", id);
|
||||
Ok(false)
|
||||
}
|
||||
} else {
|
||||
debug!("Failed to kill process, catalog not found: {}", catalog);
|
||||
Ok(false)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Ticket {
|
||||
pub(crate) catalog: String,
|
||||
pub(crate) manager: ProcessManagerRef,
|
||||
pub(crate) id: u64,
|
||||
pub(crate) id: ProcessId,
|
||||
pub cancellation_handle: Arc<CancellationHandle>,
|
||||
}
|
||||
|
||||
@@ -323,7 +329,7 @@ mod tests {
|
||||
assert_eq!(running_processes.len(), 2);
|
||||
|
||||
// Verify both processes are present
|
||||
let ids: Vec<u64> = running_processes.iter().map(|p| p.id).collect();
|
||||
let ids: Vec<u32> = running_processes.iter().map(|p| p.id).collect();
|
||||
assert!(ids.contains(&ticket1.id));
|
||||
assert!(ids.contains(&ticket2.id));
|
||||
}
|
||||
|
||||
@@ -23,7 +23,7 @@ pub mod selector;
|
||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||
pub struct DisplayProcessId {
|
||||
pub server_addr: String,
|
||||
pub id: u64,
|
||||
pub id: u32,
|
||||
}
|
||||
|
||||
impl Display for DisplayProcessId {
|
||||
@@ -44,7 +44,7 @@ impl TryFrom<&str> for DisplayProcessId {
|
||||
let id = split
|
||||
.next()
|
||||
.context(error::ParseProcessIdSnafu { s: value })?;
|
||||
let id = u64::from_str(id)
|
||||
let id = u32::from_str(id)
|
||||
.ok()
|
||||
.context(error::ParseProcessIdSnafu { s: value })?;
|
||||
Ok(DisplayProcessId { server_addr, id })
|
||||
|
||||
@@ -18,7 +18,7 @@ use std::sync::Arc;
|
||||
use common_query::error::Result;
|
||||
use common_query::prelude::{Signature, Volatility};
|
||||
use datatypes::prelude::{ConcreteDataType, ScalarVector};
|
||||
use datatypes::vectors::{StringVector, UInt64Vector, VectorRef};
|
||||
use datatypes::vectors::{StringVector, UInt32Vector, VectorRef};
|
||||
use derive_more::Display;
|
||||
|
||||
use crate::function::{Function, FunctionContext};
|
||||
@@ -144,7 +144,7 @@ impl Function for PgBackendPidFunction {
|
||||
fn eval(&self, func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
let pid = func_ctx.query_ctx.process_id();
|
||||
|
||||
Ok(Arc::new(UInt64Vector::from_slice([pid])) as _)
|
||||
Ok(Arc::new(UInt32Vector::from_slice([pid])) as _)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -164,7 +164,7 @@ impl Function for ConnectionIdFunction {
|
||||
fn eval(&self, func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result<VectorRef> {
|
||||
let pid = func_ctx.query_ctx.process_id();
|
||||
|
||||
Ok(Arc::new(UInt64Vector::from_slice([pid])) as _)
|
||||
Ok(Arc::new(UInt32Vector::from_slice([pid])) as _)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -235,6 +235,7 @@ where
|
||||
opts.keep_alive.as_secs(),
|
||||
opts.reject_no_database.unwrap_or(false),
|
||||
)),
|
||||
Some(instance.process_manager().clone()),
|
||||
);
|
||||
handlers.insert((mysql_server, mysql_addr));
|
||||
}
|
||||
@@ -257,6 +258,7 @@ where
|
||||
opts.keep_alive.as_secs(),
|
||||
common_runtime::global_runtime(),
|
||||
user_provider.clone(),
|
||||
Some(self.instance.process_manager().clone()),
|
||||
)) as Box<dyn Server>;
|
||||
|
||||
handlers.insert((pg_server, pg_addr));
|
||||
|
||||
@@ -369,7 +369,7 @@ impl StatementExecutor {
|
||||
Statement::ShowSearchPath(_) => self.show_search_path(query_ctx).await,
|
||||
Statement::Use(db) => self.use_database(db, query_ctx).await,
|
||||
Statement::Admin(admin) => self.execute_admin_command(admin, query_ctx).await,
|
||||
Statement::Kill(id) => self.execute_kill(query_ctx, id).await,
|
||||
Statement::Kill(kill) => self.execute_kill(query_ctx, kill).await,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -12,11 +12,13 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use catalog::process_manager::ProcessManagerRef;
|
||||
use common_frontend::DisplayProcessId;
|
||||
use common_query::Output;
|
||||
use common_telemetry::error;
|
||||
use session::context::QueryContextRef;
|
||||
use snafu::ResultExt;
|
||||
use sql::statements::kill::Kill;
|
||||
|
||||
use crate::error;
|
||||
use crate::statement::StatementExecutor;
|
||||
@@ -25,22 +27,51 @@ impl StatementExecutor {
|
||||
pub async fn execute_kill(
|
||||
&self,
|
||||
query_ctx: QueryContextRef,
|
||||
process_id: String,
|
||||
) -> crate::error::Result<Output> {
|
||||
kill: Kill,
|
||||
) -> error::Result<Output> {
|
||||
let Some(process_manager) = self.process_manager.as_ref() else {
|
||||
error!("Process manager is not initialized");
|
||||
return error::ProcessManagerMissingSnafu.fail();
|
||||
};
|
||||
|
||||
let succ = match kill {
|
||||
Kill::ProcessId(process_id) => {
|
||||
self.kill_process_id(process_manager, query_ctx, process_id)
|
||||
.await?
|
||||
}
|
||||
Kill::ConnectionId(conn_id) => {
|
||||
self.kill_connection_id(process_manager, query_ctx, conn_id)
|
||||
.await?
|
||||
}
|
||||
};
|
||||
Ok(Output::new_with_affected_rows(if succ { 1 } else { 0 }))
|
||||
}
|
||||
|
||||
/// Handles `KILL <PROCESS_ID>` statements.
|
||||
async fn kill_process_id(
|
||||
&self,
|
||||
pm: &ProcessManagerRef,
|
||||
query_ctx: QueryContextRef,
|
||||
process_id: String,
|
||||
) -> error::Result<bool> {
|
||||
let display_id = DisplayProcessId::try_from(process_id.as_str())
|
||||
.map_err(|_| error::InvalidProcessIdSnafu { id: process_id }.build())?;
|
||||
|
||||
let current_user_catalog = query_ctx.current_catalog().to_string();
|
||||
process_manager
|
||||
.kill_process(display_id.server_addr, current_user_catalog, display_id.id)
|
||||
pm.kill_process(display_id.server_addr, current_user_catalog, display_id.id)
|
||||
.await
|
||||
.context(error::CatalogSnafu)?;
|
||||
.context(error::CatalogSnafu)
|
||||
}
|
||||
|
||||
Ok(Output::new_with_affected_rows(0))
|
||||
/// Handles MySQL `KILL QUERY <CONNECTION_ID>` statements.
|
||||
pub async fn kill_connection_id(
|
||||
&self,
|
||||
pm: &ProcessManagerRef,
|
||||
query_ctx: QueryContextRef,
|
||||
connection_id: u32,
|
||||
) -> error::Result<bool> {
|
||||
pm.kill_local_process(query_ctx.current_catalog().to_string(), connection_id)
|
||||
.await
|
||||
.context(error::CatalogSnafu)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -82,6 +82,7 @@ pub struct MysqlInstanceShim {
|
||||
user_provider: Option<UserProviderRef>,
|
||||
prepared_stmts: Arc<RwLock<HashMap<String, SqlPlan>>>,
|
||||
prepared_stmts_counter: AtomicU32,
|
||||
process_id: u32,
|
||||
}
|
||||
|
||||
impl MysqlInstanceShim {
|
||||
@@ -89,6 +90,7 @@ impl MysqlInstanceShim {
|
||||
query_handler: ServerSqlQueryHandlerRef,
|
||||
user_provider: Option<UserProviderRef>,
|
||||
client_addr: SocketAddr,
|
||||
process_id: u32,
|
||||
) -> MysqlInstanceShim {
|
||||
// init a random salt
|
||||
let mut bs = vec![0u8; 20];
|
||||
@@ -110,12 +112,12 @@ impl MysqlInstanceShim {
|
||||
Some(client_addr),
|
||||
Channel::Mysql,
|
||||
Default::default(),
|
||||
// TODO(sunng87): generate process id properly
|
||||
0,
|
||||
process_id,
|
||||
)),
|
||||
user_provider,
|
||||
prepared_stmts: Default::default(),
|
||||
prepared_stmts_counter: AtomicU32::new(1),
|
||||
process_id,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -341,6 +343,10 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
|
||||
std::env::var("GREPTIMEDB_MYSQL_SERVER_VERSION").unwrap_or_else(|_| "8.4.2".to_string())
|
||||
}
|
||||
|
||||
fn connect_id(&self) -> u32 {
|
||||
self.process_id
|
||||
}
|
||||
|
||||
fn default_auth_plugin(&self) -> &str {
|
||||
self.auth_plugin()
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use auth::UserProviderRef;
|
||||
use catalog::process_manager::ProcessManagerRef;
|
||||
use common_runtime::runtime::RuntimeTrait;
|
||||
use common_runtime::Runtime;
|
||||
use common_telemetry::{debug, warn};
|
||||
@@ -112,6 +113,7 @@ pub struct MysqlServer {
|
||||
spawn_ref: Arc<MysqlSpawnRef>,
|
||||
spawn_config: Arc<MysqlSpawnConfig>,
|
||||
bind_addr: Option<SocketAddr>,
|
||||
process_manager: Option<ProcessManagerRef>,
|
||||
}
|
||||
|
||||
impl MysqlServer {
|
||||
@@ -119,16 +121,23 @@ impl MysqlServer {
|
||||
io_runtime: Runtime,
|
||||
spawn_ref: Arc<MysqlSpawnRef>,
|
||||
spawn_config: Arc<MysqlSpawnConfig>,
|
||||
process_manager: Option<ProcessManagerRef>,
|
||||
) -> Box<dyn Server> {
|
||||
Box::new(MysqlServer {
|
||||
base_server: BaseTcpServer::create_server("MySQL", io_runtime),
|
||||
spawn_ref,
|
||||
spawn_config,
|
||||
bind_addr: None,
|
||||
process_manager,
|
||||
})
|
||||
}
|
||||
|
||||
fn accept(&self, io_runtime: Runtime, stream: AbortableStream) -> impl Future<Output = ()> {
|
||||
fn accept(
|
||||
&self,
|
||||
io_runtime: Runtime,
|
||||
stream: AbortableStream,
|
||||
process_manager: Option<ProcessManagerRef>,
|
||||
) -> impl Future<Output = ()> {
|
||||
let spawn_ref = self.spawn_ref.clone();
|
||||
let spawn_config = self.spawn_config.clone();
|
||||
|
||||
@@ -136,7 +145,7 @@ impl MysqlServer {
|
||||
let spawn_ref = spawn_ref.clone();
|
||||
let spawn_config = spawn_config.clone();
|
||||
let io_runtime = io_runtime.clone();
|
||||
|
||||
let process_id = process_manager.as_ref().map(|p| p.next_id()).unwrap_or(8);
|
||||
async move {
|
||||
match tcp_stream {
|
||||
Err(e) => warn!(e; "Broken pipe"), // IoError doesn't impl ErrorExt.
|
||||
@@ -146,7 +155,7 @@ impl MysqlServer {
|
||||
}
|
||||
io_runtime.spawn(async move {
|
||||
if let Err(error) =
|
||||
Self::handle(io_stream, spawn_ref, spawn_config).await
|
||||
Self::handle(io_stream, spawn_ref, spawn_config, process_id).await
|
||||
{
|
||||
warn!(error; "Unexpected error when handling TcpStream");
|
||||
};
|
||||
@@ -161,10 +170,11 @@ impl MysqlServer {
|
||||
stream: TcpStream,
|
||||
spawn_ref: Arc<MysqlSpawnRef>,
|
||||
spawn_config: Arc<MysqlSpawnConfig>,
|
||||
process_id: u32,
|
||||
) -> Result<()> {
|
||||
debug!("MySQL connection coming from: {}", stream.peer_addr()?);
|
||||
crate::metrics::METRIC_MYSQL_CONNECTIONS.inc();
|
||||
if let Err(e) = Self::do_handle(stream, spawn_ref, spawn_config).await {
|
||||
if let Err(e) = Self::do_handle(stream, spawn_ref, spawn_config, process_id).await {
|
||||
if let Error::InternalIo { error } = &e
|
||||
&& error.kind() == std::io::ErrorKind::ConnectionAborted
|
||||
{
|
||||
@@ -184,11 +194,13 @@ impl MysqlServer {
|
||||
stream: TcpStream,
|
||||
spawn_ref: Arc<MysqlSpawnRef>,
|
||||
spawn_config: Arc<MysqlSpawnConfig>,
|
||||
process_id: u32,
|
||||
) -> Result<()> {
|
||||
let mut shim = MysqlInstanceShim::create(
|
||||
spawn_ref.query_handler(),
|
||||
spawn_ref.user_provider(),
|
||||
stream.peer_addr()?,
|
||||
process_id,
|
||||
);
|
||||
let (mut r, w) = stream.into_split();
|
||||
let mut w = BufWriter::with_capacity(DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE, w);
|
||||
@@ -230,7 +242,11 @@ impl Server for MysqlServer {
|
||||
.await?;
|
||||
let io_runtime = self.base_server.io_runtime();
|
||||
|
||||
let join_handle = common_runtime::spawn_global(self.accept(io_runtime, stream));
|
||||
let join_handle = common_runtime::spawn_global(self.accept(
|
||||
io_runtime,
|
||||
stream,
|
||||
self.process_manager.clone(),
|
||||
));
|
||||
self.base_server.start_with(join_handle).await?;
|
||||
|
||||
self.bind_addr = Some(addr);
|
||||
|
||||
@@ -119,9 +119,13 @@ impl PgWireServerHandlers for PostgresServerHandler {
|
||||
}
|
||||
|
||||
impl MakePostgresServerHandler {
|
||||
fn make(&self, addr: Option<SocketAddr>) -> PostgresServerHandler {
|
||||
// TODO(sunng87): generate pid from process manager
|
||||
let session = Arc::new(Session::new(addr, Channel::Postgres, Default::default(), 0));
|
||||
fn make(&self, addr: Option<SocketAddr>, process_id: u32) -> PostgresServerHandler {
|
||||
let session = Arc::new(Session::new(
|
||||
addr,
|
||||
Channel::Postgres,
|
||||
Default::default(),
|
||||
process_id,
|
||||
));
|
||||
let handler = PostgresServerHandlerInner {
|
||||
query_handler: self.query_handler.clone(),
|
||||
login_verifier: PgLoginVerifier::new(self.user_provider.clone()),
|
||||
|
||||
@@ -18,6 +18,7 @@ use std::sync::Arc;
|
||||
|
||||
use ::auth::UserProviderRef;
|
||||
use async_trait::async_trait;
|
||||
use catalog::process_manager::ProcessManagerRef;
|
||||
use common_runtime::runtime::RuntimeTrait;
|
||||
use common_runtime::Runtime;
|
||||
use common_telemetry::{debug, warn};
|
||||
@@ -37,6 +38,7 @@ pub struct PostgresServer {
|
||||
tls_server_config: Arc<ReloadableTlsServerConfig>,
|
||||
keep_alive_secs: u64,
|
||||
bind_addr: Option<SocketAddr>,
|
||||
process_manager: Option<ProcessManagerRef>,
|
||||
}
|
||||
|
||||
impl PostgresServer {
|
||||
@@ -48,6 +50,7 @@ impl PostgresServer {
|
||||
keep_alive_secs: u64,
|
||||
io_runtime: Runtime,
|
||||
user_provider: Option<UserProviderRef>,
|
||||
process_manager: Option<ProcessManagerRef>,
|
||||
) -> PostgresServer {
|
||||
let make_handler = Arc::new(
|
||||
MakePostgresServerHandlerBuilder::default()
|
||||
@@ -63,6 +66,7 @@ impl PostgresServer {
|
||||
tls_server_config,
|
||||
keep_alive_secs,
|
||||
bind_addr: None,
|
||||
process_manager,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -73,12 +77,12 @@ impl PostgresServer {
|
||||
) -> impl Future<Output = ()> {
|
||||
let handler_maker = self.make_handler.clone();
|
||||
let tls_server_config = self.tls_server_config.clone();
|
||||
let process_manager = self.process_manager.clone();
|
||||
accepting_stream.for_each(move |tcp_stream| {
|
||||
let io_runtime = io_runtime.clone();
|
||||
|
||||
let tls_acceptor = tls_server_config.get_server_config().map(TlsAcceptor::from);
|
||||
|
||||
let handler_maker = handler_maker.clone();
|
||||
let process_id = process_manager.as_ref().map(|p| p.next_id()).unwrap_or(0);
|
||||
|
||||
async move {
|
||||
match tcp_stream {
|
||||
@@ -97,7 +101,7 @@ impl PostgresServer {
|
||||
|
||||
let _handle = io_runtime.spawn(async move {
|
||||
crate::metrics::METRIC_POSTGRES_CONNECTIONS.inc();
|
||||
let pg_handler = Arc::new(handler_maker.make(addr));
|
||||
let pg_handler = Arc::new(handler_maker.make(addr, process_id));
|
||||
let r =
|
||||
process_socket(io_stream, tls_acceptor.clone(), pg_handler).await;
|
||||
crate::metrics::METRIC_POSTGRES_CONNECTIONS.dec();
|
||||
|
||||
@@ -73,6 +73,7 @@ fn create_mysql_server(table: TableRef, opts: MysqlOpts<'_>) -> Result<Box<dyn S
|
||||
0,
|
||||
opts.reject_no_database,
|
||||
)),
|
||||
None,
|
||||
))
|
||||
}
|
||||
|
||||
|
||||
@@ -71,6 +71,7 @@ fn create_postgres_server(
|
||||
0,
|
||||
io_runtime,
|
||||
user_provider,
|
||||
None,
|
||||
)))
|
||||
}
|
||||
|
||||
|
||||
@@ -66,7 +66,7 @@ pub struct QueryContext {
|
||||
channel: Channel,
|
||||
/// Process id for managing on-going queries
|
||||
#[builder(default)]
|
||||
process_id: u64,
|
||||
process_id: u32,
|
||||
/// Connection information
|
||||
#[builder(default)]
|
||||
conn_info: ConnInfo,
|
||||
@@ -439,7 +439,7 @@ impl QueryContext {
|
||||
.copied()
|
||||
}
|
||||
|
||||
pub fn process_id(&self) -> u64 {
|
||||
pub fn process_id(&self) -> u32 {
|
||||
self.process_id
|
||||
}
|
||||
|
||||
|
||||
@@ -41,8 +41,8 @@ pub struct Session {
|
||||
mutable_inner: Arc<RwLock<MutableInner>>,
|
||||
conn_info: ConnInfo,
|
||||
configuration_variables: Arc<ConfigurationVariables>,
|
||||
// the connection id to use when killing the query
|
||||
process_id: u64,
|
||||
// the process id to use when killing the query
|
||||
process_id: u32,
|
||||
}
|
||||
|
||||
pub type SessionRef = Arc<Session>;
|
||||
@@ -77,7 +77,7 @@ impl Session {
|
||||
addr: Option<SocketAddr>,
|
||||
channel: Channel,
|
||||
configuration_variables: ConfigurationVariables,
|
||||
process_id: u64,
|
||||
process_id: u32,
|
||||
) -> Self {
|
||||
Session {
|
||||
catalog: RwLock::new(DEFAULT_CATALOG_NAME.into()),
|
||||
@@ -152,7 +152,7 @@ impl Session {
|
||||
build_db_string(&self.catalog(), &self.schema())
|
||||
}
|
||||
|
||||
pub fn process_id(&self) -> u64 {
|
||||
pub fn process_id(&self) -> u32 {
|
||||
self.process_id
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,8 +12,10 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::str::FromStr;
|
||||
|
||||
use snafu::ResultExt;
|
||||
use sqlparser::ast::{Ident, Query};
|
||||
use sqlparser::ast::{Ident, Query, Value};
|
||||
use sqlparser::dialect::Dialect;
|
||||
use sqlparser::keywords::Keyword;
|
||||
use sqlparser::parser::{Parser, ParserError, ParserOptions};
|
||||
@@ -22,6 +24,7 @@ use sqlparser::tokenizer::{Token, TokenWithSpan};
|
||||
use crate::ast::{Expr, ObjectName};
|
||||
use crate::error::{self, Result, SyntaxSnafu};
|
||||
use crate::parsers::tql_parser;
|
||||
use crate::statements::kill::Kill;
|
||||
use crate::statements::statement::Statement;
|
||||
use crate::statements::transform_statements;
|
||||
|
||||
@@ -190,14 +193,43 @@ impl ParserContext<'_> {
|
||||
|
||||
Keyword::KILL => {
|
||||
let _ = self.parser.next_token();
|
||||
let process_id_ident =
|
||||
self.parser.parse_literal_string().with_context(|_| {
|
||||
error::UnexpectedSnafu {
|
||||
expected: "process id string literal",
|
||||
actual: self.peek_token_as_string(),
|
||||
let kill = if self.parser.parse_keyword(Keyword::QUERY) {
|
||||
// MySQL KILL QUERY <connection id> statements
|
||||
let connection_id_exp =
|
||||
self.parser.parse_number_value().with_context(|_| {
|
||||
error::UnexpectedSnafu {
|
||||
expected: "MySQL numeric connection id",
|
||||
actual: self.peek_token_as_string(),
|
||||
}
|
||||
})?;
|
||||
let Value::Number(s, _) = connection_id_exp else {
|
||||
return error::UnexpectedTokenSnafu {
|
||||
expected: "MySQL numeric connection id",
|
||||
actual: connection_id_exp.to_string(),
|
||||
}
|
||||
.fail();
|
||||
};
|
||||
|
||||
let connection_id = u32::from_str(&s).map_err(|_| {
|
||||
error::UnexpectedTokenSnafu {
|
||||
expected: "MySQL numeric connection id",
|
||||
actual: s,
|
||||
}
|
||||
.build()
|
||||
})?;
|
||||
Ok(Statement::Kill(process_id_ident))
|
||||
Kill::ConnectionId(connection_id)
|
||||
} else {
|
||||
let process_id_ident =
|
||||
self.parser.parse_literal_string().with_context(|_| {
|
||||
error::UnexpectedSnafu {
|
||||
expected: "process id string literal",
|
||||
actual: self.peek_token_as_string(),
|
||||
}
|
||||
})?;
|
||||
Kill::ProcessId(process_id_ident)
|
||||
};
|
||||
|
||||
Ok(Statement::Kill(kill))
|
||||
}
|
||||
|
||||
_ => self.unsupported(self.peek_token_as_string()),
|
||||
@@ -440,4 +472,191 @@ mod tests {
|
||||
let stmt_name = ParserContext::parse_mysql_deallocate_stmt(sql, &MySqlDialect {}).unwrap();
|
||||
assert_eq!(stmt_name, "stmt2");
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_parse_kill_query_statement() {
|
||||
use crate::statements::kill::Kill;
|
||||
|
||||
// Test MySQL-style KILL QUERY with connection ID
|
||||
let sql = "KILL QUERY 123";
|
||||
let statements =
|
||||
ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(statements.len(), 1);
|
||||
match &statements[0] {
|
||||
Statement::Kill(Kill::ConnectionId(connection_id)) => {
|
||||
assert_eq!(*connection_id, 123);
|
||||
}
|
||||
_ => panic!("Expected Kill::ConnectionId statement"),
|
||||
}
|
||||
|
||||
// Test with larger connection ID
|
||||
let sql = "KILL QUERY 999999";
|
||||
let statements =
|
||||
ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(statements.len(), 1);
|
||||
match &statements[0] {
|
||||
Statement::Kill(Kill::ConnectionId(connection_id)) => {
|
||||
assert_eq!(*connection_id, 999999);
|
||||
}
|
||||
_ => panic!("Expected Kill::ConnectionId statement"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_parse_kill_process_statement() {
|
||||
use crate::statements::kill::Kill;
|
||||
|
||||
// Test KILL with process ID string
|
||||
let sql = "KILL 'process-123'";
|
||||
let statements =
|
||||
ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(statements.len(), 1);
|
||||
match &statements[0] {
|
||||
Statement::Kill(Kill::ProcessId(process_id)) => {
|
||||
assert_eq!(process_id, "process-123");
|
||||
}
|
||||
_ => panic!("Expected Kill::ProcessId statement"),
|
||||
}
|
||||
|
||||
// Test with double quotes
|
||||
let sql = "KILL \"process-456\"";
|
||||
let statements =
|
||||
ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(statements.len(), 1);
|
||||
match &statements[0] {
|
||||
Statement::Kill(Kill::ProcessId(process_id)) => {
|
||||
assert_eq!(process_id, "process-456");
|
||||
}
|
||||
_ => panic!("Expected Kill::ProcessId statement"),
|
||||
}
|
||||
|
||||
// Test with UUID-like process ID
|
||||
let sql = "KILL 'f47ac10b-58cc-4372-a567-0e02b2c3d479'";
|
||||
let statements =
|
||||
ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(statements.len(), 1);
|
||||
match &statements[0] {
|
||||
Statement::Kill(Kill::ProcessId(process_id)) => {
|
||||
assert_eq!(process_id, "f47ac10b-58cc-4372-a567-0e02b2c3d479");
|
||||
}
|
||||
_ => panic!("Expected Kill::ProcessId statement"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_parse_kill_statement_errors() {
|
||||
// Test KILL QUERY without connection ID
|
||||
let sql = "KILL QUERY";
|
||||
let result =
|
||||
ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default());
|
||||
assert!(result.is_err());
|
||||
|
||||
// Test KILL QUERY with non-numeric connection ID
|
||||
let sql = "KILL QUERY 'not-a-number'";
|
||||
let result =
|
||||
ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default());
|
||||
assert!(result.is_err());
|
||||
|
||||
// Test KILL without any argument
|
||||
let sql = "KILL";
|
||||
let result =
|
||||
ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default());
|
||||
assert!(result.is_err());
|
||||
|
||||
// Test KILL QUERY with connection ID that's too large for u32
|
||||
let sql = "KILL QUERY 4294967296"; // u32::MAX + 1
|
||||
let result =
|
||||
ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default());
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_parse_kill_statement_edge_cases() {
|
||||
use crate::statements::kill::Kill;
|
||||
|
||||
// Test KILL QUERY with zero connection ID
|
||||
let sql = "KILL QUERY 0";
|
||||
let statements =
|
||||
ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(statements.len(), 1);
|
||||
match &statements[0] {
|
||||
Statement::Kill(Kill::ConnectionId(connection_id)) => {
|
||||
assert_eq!(*connection_id, 0);
|
||||
}
|
||||
_ => panic!("Expected Kill::ConnectionId statement"),
|
||||
}
|
||||
|
||||
// Test KILL QUERY with maximum u32 value
|
||||
let sql = "KILL QUERY 4294967295"; // u32::MAX
|
||||
let statements =
|
||||
ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(statements.len(), 1);
|
||||
match &statements[0] {
|
||||
Statement::Kill(Kill::ConnectionId(connection_id)) => {
|
||||
assert_eq!(*connection_id, 4294967295);
|
||||
}
|
||||
_ => panic!("Expected Kill::ConnectionId statement"),
|
||||
}
|
||||
|
||||
// Test KILL with empty string process ID
|
||||
let sql = "KILL ''";
|
||||
let statements =
|
||||
ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(statements.len(), 1);
|
||||
match &statements[0] {
|
||||
Statement::Kill(Kill::ProcessId(process_id)) => {
|
||||
assert_eq!(process_id, "");
|
||||
}
|
||||
_ => panic!("Expected Kill::ProcessId statement"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_parse_kill_statement_case_insensitive() {
|
||||
use crate::statements::kill::Kill;
|
||||
|
||||
// Test lowercase
|
||||
let sql = "kill query 123";
|
||||
let statements =
|
||||
ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(statements.len(), 1);
|
||||
match &statements[0] {
|
||||
Statement::Kill(Kill::ConnectionId(connection_id)) => {
|
||||
assert_eq!(*connection_id, 123);
|
||||
}
|
||||
_ => panic!("Expected Kill::ConnectionId statement"),
|
||||
}
|
||||
|
||||
// Test mixed case
|
||||
let sql = "Kill Query 456";
|
||||
let statements =
|
||||
ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(statements.len(), 1);
|
||||
match &statements[0] {
|
||||
Statement::Kill(Kill::ConnectionId(connection_id)) => {
|
||||
assert_eq!(*connection_id, 456);
|
||||
}
|
||||
_ => panic!("Expected Kill::ConnectionId statement"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,6 +22,7 @@ pub mod describe;
|
||||
pub mod drop;
|
||||
pub mod explain;
|
||||
pub mod insert;
|
||||
pub mod kill;
|
||||
mod option_map;
|
||||
pub mod query;
|
||||
pub mod set_variables;
|
||||
|
||||
40
src/sql/src/statements/kill.rs
Normal file
40
src/sql/src/statements/kill.rs
Normal file
@@ -0,0 +1,40 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::fmt::{Display, Formatter};
|
||||
|
||||
use serde::Serialize;
|
||||
use sqlparser_derive::{Visit, VisitMut};
|
||||
|
||||
/// Arguments of `KILL` statements.
|
||||
#[derive(Debug, Clone, Eq, PartialEq, Visit, VisitMut, Serialize)]
|
||||
pub enum Kill {
|
||||
/// Kill a remote process id.
|
||||
ProcessId(String),
|
||||
/// Kill MySQL connection id.
|
||||
ConnectionId(u32),
|
||||
}
|
||||
|
||||
impl Display for Kill {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Kill::ProcessId(id) => {
|
||||
write!(f, "KILL {}", id)
|
||||
}
|
||||
Kill::ConnectionId(id) => {
|
||||
write!(f, "KILL QUERY {}", id)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -32,6 +32,7 @@ use crate::statements::describe::DescribeTable;
|
||||
use crate::statements::drop::{DropDatabase, DropFlow, DropTable, DropView};
|
||||
use crate::statements::explain::Explain;
|
||||
use crate::statements::insert::Insert;
|
||||
use crate::statements::kill::Kill;
|
||||
use crate::statements::query::Query;
|
||||
use crate::statements::set_variables::SetVariables;
|
||||
use crate::statements::show::{
|
||||
@@ -139,7 +140,7 @@ pub enum Statement {
|
||||
// CLOSE
|
||||
CloseCursor(CloseCursor),
|
||||
// KILL <process>
|
||||
Kill(String),
|
||||
Kill(Kill),
|
||||
}
|
||||
|
||||
impl Display for Statement {
|
||||
|
||||
@@ -648,6 +648,7 @@ pub async fn setup_mysql_server_with_user_provider(
|
||||
0,
|
||||
opts.reject_no_database.unwrap_or(false),
|
||||
)),
|
||||
None,
|
||||
);
|
||||
|
||||
mysql_server
|
||||
@@ -697,6 +698,7 @@ pub async fn setup_pg_server_with_user_provider(
|
||||
0,
|
||||
runtime,
|
||||
user_provider,
|
||||
None,
|
||||
));
|
||||
|
||||
pg_server
|
||||
|
||||
Reference in New Issue
Block a user