diff --git a/Cargo.lock b/Cargo.lock index 792907d9be..3fbb42e9c1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1621,6 +1621,7 @@ dependencies = [ "cache", "catalog", "chrono", + "common-base", "common-catalog", "common-error", "common-frontend", @@ -4755,6 +4756,7 @@ dependencies = [ "substrait 0.15.0", "table", "tokio", + "tokio-util", "toml 0.8.19", "tonic 0.12.3", "tower 0.5.2", @@ -5141,7 +5143,7 @@ dependencies = [ [[package]] name = "greptime-proto" version = "0.1.0" -source = "git+https://github.com/GreptimeTeam/greptime-proto.git?rev=5f6119ac7952878d39dcde0343c4bf828d18ffc8#5f6119ac7952878d39dcde0343c4bf828d18ffc8" +source = "git+https://github.com/GreptimeTeam/greptime-proto.git?rev=f0913f179ee1d2ce428f8b85a9ea12b5f69ad636#f0913f179ee1d2ce428f8b85a9ea12b5f69ad636" dependencies = [ "prost 0.13.5", "serde", @@ -8419,6 +8421,7 @@ dependencies = [ "common-catalog", "common-datasource", "common-error", + "common-frontend", "common-function", "common-grpc", "common-grpc-expr", diff --git a/Cargo.toml b/Cargo.toml index f27d1fa586..409b797685 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -134,7 +134,7 @@ etcd-client = "0.14" fst = "0.4.7" futures = "0.3" futures-util = "0.3" -greptime-proto = { git = "https://github.com/GreptimeTeam/greptime-proto.git", rev = "5f6119ac7952878d39dcde0343c4bf828d18ffc8" } +greptime-proto = { git = "https://github.com/GreptimeTeam/greptime-proto.git", rev = "f0913f179ee1d2ce428f8b85a9ea12b5f69ad636" } hex = "0.4" http = "1" humantime = "2.1" diff --git a/src/catalog/Cargo.toml b/src/catalog/Cargo.toml index e5d8887357..c7e2782c0e 100644 --- a/src/catalog/Cargo.toml +++ b/src/catalog/Cargo.toml @@ -17,6 +17,7 @@ arrow-schema.workspace = true async-stream.workspace = true async-trait.workspace = true bytes.workspace = true +common-base.workspace = true common-catalog.workspace = true common-error.workspace = true common-frontend.workspace = true diff --git a/src/catalog/src/error.rs b/src/catalog/src/error.rs index 165f99f28f..30edab94df 100644 --- a/src/catalog/src/error.rs +++ b/src/catalog/src/error.rs @@ -278,12 +278,25 @@ pub enum Error { location: Location, }, - #[snafu(display("Failed to list frontend nodes"))] - ListProcess { + #[snafu(display("Failed to invoke frontend services"))] + InvokeFrontend { source: common_frontend::error::Error, #[snafu(implicit)] location: Location, }, + + #[snafu(display("Meta client is not provided"))] + MetaClientMissing { + #[snafu(implicit)] + location: Location, + }, + + #[snafu(display("Failed to find frontend node: {}", addr))] + FrontendNotFound { + addr: String, + #[snafu(implicit)] + location: Location, + }, } impl Error { @@ -352,7 +365,10 @@ impl ErrorExt for Error { Error::GetViewCache { source, .. } | Error::GetTableCache { source, .. } => { source.status_code() } - Error::ListProcess { source, .. } => source.status_code(), + Error::InvokeFrontend { source, .. } => source.status_code(), + Error::FrontendNotFound { .. } | Error::MetaClientMissing { .. } => { + StatusCode::Unexpected + } } } diff --git a/src/catalog/src/metrics.rs b/src/catalog/src/metrics.rs index 77063d64ba..635d2bdf0d 100644 --- a/src/catalog/src/metrics.rs +++ b/src/catalog/src/metrics.rs @@ -34,4 +34,20 @@ lazy_static! { register_histogram!("greptime_catalog_kv_get", "catalog kv get").unwrap(); pub static ref METRIC_CATALOG_KV_BATCH_GET: Histogram = register_histogram!("greptime_catalog_kv_batch_get", "catalog kv batch get").unwrap(); + + /// Count of running process in each catalog. 
+ pub static ref PROCESS_LIST_COUNT: IntGaugeVec = register_int_gauge_vec!( + "greptime_process_list_count", + "Running process count per catalog", + &["catalog"] + ) + .unwrap(); + + /// Count of killed process in each catalog. + pub static ref PROCESS_KILL_COUNT: IntCounterVec = register_int_counter_vec!( + "greptime_process_kill_count", + "Completed kill process requests count", + &["catalog"] + ) + .unwrap(); }
diff --git a/src/catalog/src/process_manager.rs b/src/catalog/src/process_manager.rs index fa79848dde..3f7950da7d 100644 --- a/src/catalog/src/process_manager.rs +++ b/src/catalog/src/process_manager.rs @@ -14,24 +14,32 @@ use std::collections::hash_map::Entry; use std::collections::HashMap; +use std::fmt::{Debug, Formatter}; use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::{Arc, RwLock}; -use api::v1::frontend::{ListProcessRequest, ProcessInfo}; +use api::v1::frontend::{KillProcessRequest, ListProcessRequest, ProcessInfo}; +use common_base::cancellation::CancellationHandle; use common_frontend::selector::{FrontendSelector, MetaClientSelector}; use common_telemetry::{debug, info}; use common_time::util::current_time_millis; use meta_client::MetaClientRef; -use snafu::ResultExt; +use snafu::{ensure, OptionExt, ResultExt}; use crate::error; +use crate::metrics::{PROCESS_KILL_COUNT, PROCESS_LIST_COUNT}; pub type ProcessManagerRef = Arc<ProcessManager>; +/// Query process manager. pub struct ProcessManager { + /// Local frontend server address. server_addr: String, + /// Next process id for local queries. next_id: AtomicU64, - catalogs: RwLock<HashMap<String, HashMap<u64, ProcessInfo>>>, + /// Running processes per catalog. + catalogs: RwLock<HashMap<String, HashMap<u64, CancellableProcess>>>, + /// Frontend selector to locate frontend nodes. frontend_selector: Option<MetaClientSelector>, } @@ -50,6 +58,7 @@ impl ProcessManager { impl ProcessManager { /// Registers a submitted query. Use the provided id if present. + #[must_use] pub fn register_query( self: &Arc<Self>, catalog: String, @@ -68,16 +77,21 @@ impl ProcessManager { client, frontend: self.server_addr.clone(), }; + let cancellation_handle = Arc::new(CancellationHandle::default()); + let cancellable_process = CancellableProcess::new(cancellation_handle.clone(), process); + self.catalogs .write() .unwrap() .entry(catalog.clone()) .or_default() - .insert(id, process); + .insert(id, cancellable_process); + Ticket { catalog, manager: self.clone(), id, + cancellation_handle, } } @@ -91,30 +105,25 @@ impl ProcessManager { if let Entry::Occupied(mut o) = self.catalogs.write().unwrap().entry(catalog) { let process = o.get_mut().remove(&id); debug!("Deregister process: {:?}", process); - if o.get_mut().is_empty() { + if o.get().is_empty() { o.remove(); } } } - pub fn deregister_all_queries(&self) { - self.catalogs.write().unwrap().clear(); - info!("All queries on {} has been deregistered", self.server_addr); - } - /// List local running processes in given catalog.
pub fn local_processes(&self, catalog: Option<&str>) -> error::Result<Vec<ProcessInfo>> { let catalogs = self.catalogs.read().unwrap(); let result = if let Some(catalog) = catalog { if let Some(catalogs) = catalogs.get(catalog) { - catalogs.values().cloned().collect() + catalogs.values().map(|p| p.process.clone()).collect() } else { vec![] } } else { catalogs .values() - .flat_map(|v| v.values().cloned()) + .flat_map(|v| v.values().map(|p| p.process.clone())) .collect() }; Ok(result) } @@ -129,14 +138,14 @@ impl ProcessManager { let frontends = remote_frontend_selector .select(|node| node.peer.addr != self.server_addr) .await - .context(error::ListProcessSnafu)?; + .context(error::InvokeFrontendSnafu)?; for mut f in frontends { processes.extend( f.list_process(ListProcessRequest { catalog: catalog.unwrap_or_default().to_string(), }) .await - .context(error::ListProcessSnafu)? + .context(error::InvokeFrontendSnafu)? .processes, ); } @@ -144,12 +153,64 @@ impl ProcessManager { processes.extend(self.local_processes(catalog)?); Ok(processes) } + + /// Kills query with provided catalog and id. + pub async fn kill_process( + &self, + server_addr: String, + catalog: String, + id: u64, + ) -> error::Result<bool> { + if server_addr == self.server_addr { + if let Some(catalogs) = self.catalogs.write().unwrap().get_mut(&catalog) { + if let Some(process) = catalogs.remove(&id) { + process.handle.cancel(); + info!( + "Killed process, catalog: {}, id: {:?}", + process.process.catalog, process.process.id + ); + PROCESS_KILL_COUNT.with_label_values(&[&catalog]).inc(); + Ok(true) + } else { + debug!("Failed to kill process, id not found: {}", id); + Ok(false) + } + } else { + debug!("Failed to kill process, catalog not found: {}", catalog); + Ok(false) + } + } else { + let mut nodes = self + .frontend_selector + .as_ref() + .context(error::MetaClientMissingSnafu)?
+ .select(|node| node.peer.addr == server_addr) + .await + .context(error::InvokeFrontendSnafu)?; + ensure!( + !nodes.is_empty(), + error::FrontendNotFoundSnafu { addr: server_addr } + ); + + let request = KillProcessRequest { + server_addr, + catalog, + process_id: id, + }; + nodes[0] + .kill_process(request) + .await + .context(error::InvokeFrontendSnafu)?; + Ok(true) + } + } } pub struct Ticket { pub(crate) catalog: String, pub(crate) manager: ProcessManagerRef, pub(crate) id: u64, + pub cancellation_handle: Arc<CancellationHandle>, } impl Drop for Ticket { @@ -159,6 +220,37 @@ impl Drop for Ticket { } } +struct CancellableProcess { + handle: Arc<CancellationHandle>, + process: ProcessInfo, +} + +impl Drop for CancellableProcess { + fn drop(&mut self) { + PROCESS_LIST_COUNT + .with_label_values(&[&self.process.catalog]) + .dec(); + } +} + +impl CancellableProcess { + fn new(handle: Arc<CancellationHandle>, process: ProcessInfo) -> Self { + PROCESS_LIST_COUNT + .with_label_values(&[&process.catalog]) + .inc(); + Self { handle, process } + } +} + +impl Debug for CancellableProcess { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CancellableProcess") + .field("cancelled", &self.handle.is_cancelled()) + .field("process", &self.process) + .finish() + } +} + #[cfg(test)] mod tests { use std::sync::Arc; @@ -185,4 +277,212 @@ mod tests { drop(ticket); assert_eq!(process_manager.local_processes(None).unwrap().len(), 0); } + + #[tokio::test] + async fn test_register_query_with_custom_id() { + let process_manager = Arc::new(ProcessManager::new("127.0.0.1:8000".to_string(), None)); + let custom_id = 12345; + + let ticket = process_manager.clone().register_query( + "public".to_string(), + vec!["test".to_string()], + "SELECT * FROM table".to_string(), + "client1".to_string(), + Some(custom_id), + ); + + assert_eq!(ticket.id, custom_id); + + let running_processes = process_manager.local_processes(None).unwrap(); + assert_eq!(running_processes.len(), 1); + assert_eq!(running_processes[0].id, custom_id); + assert_eq!(&running_processes[0].client, "client1"); + } + + #[tokio::test] + async fn test_multiple_queries_same_catalog() { + let process_manager = Arc::new(ProcessManager::new("127.0.0.1:8000".to_string(), None)); + + let ticket1 = process_manager.clone().register_query( + "public".to_string(), + vec!["schema1".to_string()], + "SELECT * FROM table1".to_string(), + "client1".to_string(), + None, + ); + + let ticket2 = process_manager.clone().register_query( + "public".to_string(), + vec!["schema2".to_string()], + "SELECT * FROM table2".to_string(), + "client2".to_string(), + None, + ); + + let running_processes = process_manager.local_processes(Some("public")).unwrap(); + assert_eq!(running_processes.len(), 2); + + // Verify both processes are present + let ids: Vec<u64> = running_processes.iter().map(|p| p.id).collect(); + assert!(ids.contains(&ticket1.id)); + assert!(ids.contains(&ticket2.id)); + } + + #[tokio::test] + async fn test_multiple_catalogs() { + let process_manager = Arc::new(ProcessManager::new("127.0.0.1:8000".to_string(), None)); + + let _ticket1 = process_manager.clone().register_query( + "catalog1".to_string(), + vec!["schema1".to_string()], + "SELECT * FROM table1".to_string(), + "client1".to_string(), + None, + ); + + let _ticket2 = process_manager.clone().register_query( + "catalog2".to_string(), + vec!["schema2".to_string()], + "SELECT * FROM table2".to_string(), + "client2".to_string(), + None, + ); + + // Test listing processes for specific catalog + let catalog1_processes =
process_manager.local_processes(Some("catalog1")).unwrap(); + assert_eq!(catalog1_processes.len(), 1); + assert_eq!(&catalog1_processes[0].catalog, "catalog1"); + + let catalog2_processes = process_manager.local_processes(Some("catalog2")).unwrap(); + assert_eq!(catalog2_processes.len(), 1); + assert_eq!(&catalog2_processes[0].catalog, "catalog2"); + + // Test listing all processes + let all_processes = process_manager.local_processes(None).unwrap(); + assert_eq!(all_processes.len(), 2); + } + + #[tokio::test] + async fn test_deregister_query() { + let process_manager = Arc::new(ProcessManager::new("127.0.0.1:8000".to_string(), None)); + + let ticket = process_manager.clone().register_query( + "public".to_string(), + vec!["test".to_string()], + "SELECT * FROM table".to_string(), + "client1".to_string(), + None, + ); + assert_eq!(process_manager.local_processes(None).unwrap().len(), 1); + process_manager.deregister_query("public".to_string(), ticket.id); + assert_eq!(process_manager.local_processes(None).unwrap().len(), 0); + } + + #[tokio::test] + async fn test_cancellation_handle() { + let process_manager = Arc::new(ProcessManager::new("127.0.0.1:8000".to_string(), None)); + + let ticket = process_manager.clone().register_query( + "public".to_string(), + vec!["test".to_string()], + "SELECT * FROM table".to_string(), + "client1".to_string(), + None, + ); + + assert!(!ticket.cancellation_handle.is_cancelled()); + ticket.cancellation_handle.cancel(); + assert!(ticket.cancellation_handle.is_cancelled()); + } + + #[tokio::test] + async fn test_kill_local_process() { + let process_manager = Arc::new(ProcessManager::new("127.0.0.1:8000".to_string(), None)); + + let ticket = process_manager.clone().register_query( + "public".to_string(), + vec!["test".to_string()], + "SELECT * FROM table".to_string(), + "client1".to_string(), + None, + ); + assert!(!ticket.cancellation_handle.is_cancelled()); + let killed = process_manager + .kill_process( + "127.0.0.1:8000".to_string(), + "public".to_string(), + ticket.id, + ) + .await + .unwrap(); + + assert!(killed); + assert_eq!(process_manager.local_processes(None).unwrap().len(), 0); + } + + #[tokio::test] + async fn test_kill_nonexistent_process() { + let process_manager = Arc::new(ProcessManager::new("127.0.0.1:8000".to_string(), None)); + let killed = process_manager + .kill_process("127.0.0.1:8000".to_string(), "public".to_string(), 999) + .await + .unwrap(); + assert!(!killed); + } + + #[tokio::test] + async fn test_kill_process_nonexistent_catalog() { + let process_manager = Arc::new(ProcessManager::new("127.0.0.1:8000".to_string(), None)); + let killed = process_manager + .kill_process("127.0.0.1:8000".to_string(), "nonexistent".to_string(), 1) + .await + .unwrap(); + assert!(!killed); + } + + #[tokio::test] + async fn test_process_info_fields() { + let process_manager = Arc::new(ProcessManager::new("127.0.0.1:8000".to_string(), None)); + + let _ticket = process_manager.clone().register_query( + "test_catalog".to_string(), + vec!["schema1".to_string(), "schema2".to_string()], + "SELECT COUNT(*) FROM users WHERE age > 18".to_string(), + "test_client".to_string(), + Some(42), + ); + + let processes = process_manager.local_processes(None).unwrap(); + assert_eq!(processes.len(), 1); + + let process = &processes[0]; + assert_eq!(process.id, 42); + assert_eq!(&process.catalog, "test_catalog"); + assert_eq!(process.schemas, vec!["schema1", "schema2"]); + assert_eq!(&process.query, "SELECT COUNT(*) FROM users WHERE age > 18"); + assert_eq!(&process.client, 
"test_client"); + assert_eq!(&process.frontend, "127.0.0.1:8000"); + assert!(process.start_timestamp > 0); + } + + #[tokio::test] + async fn test_ticket_drop_deregisters_process() { + let process_manager = Arc::new(ProcessManager::new("127.0.0.1:8000".to_string(), None)); + + { + let _ticket = process_manager.clone().register_query( + "public".to_string(), + vec!["test".to_string()], + "SELECT * FROM table".to_string(), + "client1".to_string(), + None, + ); + + // Process should be registered + assert_eq!(process_manager.local_processes(None).unwrap().len(), 1); + } // ticket goes out of scope here + + // Process should be automatically deregistered + assert_eq!(process_manager.local_processes(None).unwrap().len(), 0); + } } diff --git a/src/common/base/src/cancellation.rs b/src/common/base/src/cancellation.rs new file mode 100644 index 0000000000..155b964158 --- /dev/null +++ b/src/common/base/src/cancellation.rs @@ -0,0 +1,240 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! [CancellationHandle] is used to compose with manual implementation of [futures::future::Future] +//! or [futures::stream::Stream] to facilitate cancellation. +//! See example in [frontend::stream_wrapper::CancellableStreamWrapper] and [CancellableFuture]. + +use std::fmt::{Debug, Display, Formatter}; +use std::future::Future; +use std::pin::Pin; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use futures::task::AtomicWaker; +use pin_project::pin_project; + +#[derive(Default)] +pub struct CancellationHandle { + waker: AtomicWaker, + cancelled: AtomicBool, +} + +impl Debug for CancellationHandle { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CancellationHandle") + .field("cancelled", &self.is_cancelled()) + .finish() + } +} + +impl CancellationHandle { + pub fn waker(&self) -> &AtomicWaker { + &self.waker + } + + /// Cancels a future or stream. + pub fn cancel(&self) { + if self + .cancelled + .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed) + .is_ok() + { + self.waker.wake(); + } + } + + /// Is this handle cancelled. 
+ pub fn is_cancelled(&self) -> bool { + self.cancelled.load(Ordering::Relaxed) + } +} + +#[pin_project] +#[derive(Debug, Clone)] +pub struct CancellableFuture<T> { + #[pin] + fut: T, + handle: Arc<CancellationHandle>, +} + +impl<T> CancellableFuture<T> { + pub fn new(fut: T, handle: Arc<CancellationHandle>) -> Self { + Self { fut, handle } + } +} + +impl<T> Future for CancellableFuture<T> +where + T: Future, +{ + type Output = Result<T::Output, Cancelled>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let this = self.as_mut().project(); + // Check if the task has been aborted + if this.handle.is_cancelled() { + return Poll::Ready(Err(Cancelled)); + } + + if let Poll::Ready(x) = this.fut.poll(cx) { + return Poll::Ready(Ok(x)); + } + + this.handle.waker().register(cx.waker()); + if this.handle.is_cancelled() { + return Poll::Ready(Err(Cancelled)); + } + Poll::Pending + } +} + +#[derive(Copy, Clone, Debug)] +pub struct Cancelled; + +impl Display for Cancelled { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "Future has been cancelled") + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + use std::time::Duration; + + use tokio::time::{sleep, timeout}; + + use crate::cancellation::{CancellableFuture, CancellationHandle, Cancelled}; + + #[tokio::test] + async fn test_cancellable_future_completes_normally() { + let handle = Arc::new(CancellationHandle::default()); + let future = async { 42 }; + let cancellable = CancellableFuture::new(future, handle); + + let result = cancellable.await; + assert!(result.is_ok()); + assert_eq!(result.unwrap(), 42); + } + + #[tokio::test] + async fn test_cancellable_future_cancelled_before_start() { + let handle = Arc::new(CancellationHandle::default()); + handle.cancel(); + + let future = async { 42 }; + let cancellable = CancellableFuture::new(future, handle); + + let result = cancellable.await; + assert!(result.is_err()); + assert!(matches!(result.unwrap_err(), Cancelled)); + } + + #[tokio::test] + async fn test_cancellable_future_cancelled_during_execution() { + let handle = Arc::new(CancellationHandle::default()); + let handle_clone = handle.clone(); + + // Create a future that sleeps for a long time + let future = async { + sleep(Duration::from_secs(10)).await; + 42 + }; + let cancellable = CancellableFuture::new(future, handle); + + // Cancel the future after a short delay + tokio::spawn(async move { + sleep(Duration::from_millis(50)).await; + handle_clone.cancel(); + }); + + let result = cancellable.await; + assert!(result.is_err()); + assert!(matches!(result.unwrap_err(), Cancelled)); + } + + #[tokio::test] + async fn test_cancellable_future_completes_before_cancellation() { + let handle = Arc::new(CancellationHandle::default()); + let handle_clone = handle.clone(); + + // Create a future that completes quickly + let future = async { + sleep(Duration::from_millis(10)).await; + 42 + }; + let cancellable = CancellableFuture::new(future, handle); + + // Try to cancel after the future should have completed + tokio::spawn(async move { + sleep(Duration::from_millis(100)).await; + handle_clone.cancel(); + }); + + let result = cancellable.await; + assert!(result.is_ok()); + assert_eq!(result.unwrap(), 42); + } + + #[tokio::test] + async fn test_cancellation_handle_is_cancelled() { + let handle = CancellationHandle::default(); + assert!(!handle.is_cancelled()); + + handle.cancel(); + assert!(handle.is_cancelled()); + } + + #[tokio::test] + async fn test_multiple_cancellable_futures_with_same_handle() { + let handle = Arc::new(CancellationHandle::default()); + + let future1 =
CancellableFuture::new(async { 1 }, handle.clone()); + let future2 = CancellableFuture::new(async { 2 }, handle.clone()); + + // Cancel before starting + handle.cancel(); + + let (result1, result2) = tokio::join!(future1, future2); + + assert!(result1.is_err()); + assert!(result2.is_err()); + assert!(matches!(result1.unwrap_err(), Cancelled)); + assert!(matches!(result2.unwrap_err(), Cancelled)); + } + + #[tokio::test] + async fn test_cancellable_future_with_timeout() { + let handle = Arc::new(CancellationHandle::default()); + let future = async { + sleep(Duration::from_secs(1)).await; + 42 + }; + let cancellable = CancellableFuture::new(future, handle.clone()); + + // Use timeout to ensure the test doesn't hang + let result = timeout(Duration::from_millis(100), cancellable).await; + + // Should timeout because the future takes 1 second but we timeout after 100ms + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_cancelled_display() { + let cancelled = Cancelled; + assert_eq!(format!("{}", cancelled), "Future has been cancelled"); + } +} diff --git a/src/common/base/src/lib.rs b/src/common/base/src/lib.rs index 62a801d946..9fec936a13 100644 --- a/src/common/base/src/lib.rs +++ b/src/common/base/src/lib.rs @@ -14,6 +14,7 @@ pub mod bit_vec; pub mod bytes; +pub mod cancellation; pub mod plugins; pub mod range_read; #[allow(clippy::all)] diff --git a/src/common/frontend/src/error.rs b/src/common/frontend/src/error.rs index f4046a1fd9..cee8c6df77 100644 --- a/src/common/frontend/src/error.rs +++ b/src/common/frontend/src/error.rs @@ -42,8 +42,8 @@ pub enum Error { location: Location, }, - #[snafu(display("Failed to invoke list process service"))] - ListProcess { + #[snafu(display("Failed to invoke frontend service"))] + InvokeFrontend { #[snafu(source)] error: tonic::Status, #[snafu(implicit)] @@ -67,7 +67,7 @@ impl ErrorExt for Error { External { source, .. } => source.status_code(), Meta { source, .. } => source.status_code(), ParseProcessId { .. } => StatusCode::InvalidArguments, - ListProcess { .. } => StatusCode::External, + InvokeFrontend { .. } => StatusCode::Unexpected, CreateChannel { source, .. 
} => source.status_code(), } }
diff --git a/src/common/frontend/src/selector.rs b/src/common/frontend/src/selector.rs index 4e36b67b18..3536ec85d8 100644 --- a/src/common/frontend/src/selector.rs +++ b/src/common/frontend/src/selector.rs @@ -16,9 +16,13 @@ use std::time::Duration; use common_grpc::channel_manager::{ChannelConfig, ChannelManager}; use common_meta::cluster::{ClusterInfo, NodeInfo, Role}; -use greptime_proto::v1::frontend::{frontend_client, ListProcessRequest, ListProcessResponse}; +use greptime_proto::v1::frontend::{ + frontend_client, KillProcessRequest, KillProcessResponse, ListProcessRequest, + ListProcessResponse, +}; use meta_client::MetaClientRef; use snafu::ResultExt; +use tonic::Response; use crate::error; use crate::error::{MetaSnafu, Result}; @@ -28,18 +32,28 @@ pub type FrontendClientPtr = Box<dyn FrontendClient>; #[async_trait::async_trait] pub trait FrontendClient: Send { async fn list_process(&mut self, req: ListProcessRequest) -> Result<ListProcessResponse>; + + async fn kill_process(&mut self, req: KillProcessRequest) -> Result<KillProcessResponse>; } #[async_trait::async_trait] impl FrontendClient for frontend_client::FrontendClient<tonic::transport::channel::Channel> { async fn list_process(&mut self, req: ListProcessRequest) -> Result<ListProcessResponse> { - let response: ListProcessResponse = frontend_client::FrontendClient::< - tonic::transport::channel::Channel, - >::list_process(self, req) + frontend_client::FrontendClient::<tonic::transport::channel::Channel>::list_process( + self, req, + ) .await - .context(error::ListProcessSnafu)? - .into_inner(); - Ok(response) + .context(error::InvokeFrontendSnafu) + .map(Response::into_inner) + } + + async fn kill_process(&mut self, req: KillProcessRequest) -> Result<KillProcessResponse> { + frontend_client::FrontendClient::<tonic::transport::channel::Channel>::kill_process( + self, req, + ) + .await + .context(error::InvokeFrontendSnafu) + .map(Response::into_inner) + } }
diff --git a/src/common/recordbatch/src/error.rs b/src/common/recordbatch/src/error.rs index 6c89d975e5..3e48324821 100644 --- a/src/common/recordbatch/src/error.rs +++ b/src/common/recordbatch/src/error.rs @@ -173,6 +173,7 @@ pub enum Error { #[snafu(implicit)] location: Location, }, + #[snafu(display("Stream timeout"))] StreamTimeout { #[snafu(implicit)] location: Location, #[snafu(source)] error: tokio::time::error::Elapsed, }, + #[snafu(display("RecordBatch slice index overflow: {visit_index} > {size}"))] RecordBatchSliceIndexOverflow { #[snafu(implicit)] location: Location, size: usize, visit_index: usize, }, + + #[snafu(display("Stream has been cancelled"))] + StreamCancelled { + #[snafu(implicit)] + location: Location, + }, } impl ErrorExt for Error { @@ -221,6 +229,8 @@ impl ErrorExt for Error { } Error::StreamTimeout { .. } => StatusCode::Cancelled, + + Error::StreamCancelled { ..
} => StatusCode::Cancelled, } } diff --git a/src/flow/src/server.rs b/src/flow/src/server.rs index a53888f4b6..e41168f5d3 100644 --- a/src/flow/src/server.rs +++ b/src/flow/src/server.rs @@ -578,6 +578,7 @@ impl FrontendInvoker { layered_cache_registry.clone(), inserter.clone(), table_route_cache, + None, )); let invoker = FrontendInvoker::new(inserter, deleter, statement_executor); diff --git a/src/frontend/Cargo.toml b/src/frontend/Cargo.toml index 9a25b1485f..15fd74cb32 100644 --- a/src/frontend/Cargo.toml +++ b/src/frontend/Cargo.toml @@ -70,6 +70,7 @@ store-api.workspace = true substrait.workspace = true table.workspace = true tokio.workspace = true +tokio-util.workspace = true toml.workspace = true tonic.workspace = true diff --git a/src/frontend/src/error.rs b/src/frontend/src/error.rs index 7880656f51..44f6dfdeb0 100644 --- a/src/frontend/src/error.rs +++ b/src/frontend/src/error.rs @@ -357,6 +357,12 @@ pub enum Error { #[snafu(implicit)] location: Location, }, + + #[snafu(display("Query has been cancelled"))] + Cancelled { + #[snafu(implicit)] + location: Location, + }, } pub type Result = std::result::Result; @@ -435,6 +441,8 @@ impl ErrorExt for Error { Error::InFlightWriteBytesExceeded { .. } => StatusCode::RateLimited, Error::DataFusion { error, .. } => datafusion_status_code::(error, None), + + Error::Cancelled { .. } => StatusCode::Cancelled, } } diff --git a/src/frontend/src/instance.rs b/src/frontend/src/instance.rs index cc1838aff1..9ca0d09a65 100644 --- a/src/frontend/src/instance.rs +++ b/src/frontend/src/instance.rs @@ -33,6 +33,7 @@ use auth::{PermissionChecker, PermissionCheckerRef, PermissionReq}; use catalog::process_manager::ProcessManagerRef; use catalog::CatalogManagerRef; use client::OutputData; +use common_base::cancellation::CancellableFuture; use common_base::Plugins; use common_config::KvBackendConfig; use common_error::ext::{BoxedError, ErrorExt}; @@ -81,7 +82,7 @@ use crate::error::{ }; use crate::limiter::LimiterRef; use crate::slow_query_recorder::SlowQueryRecorder; -use crate::stream_wrapper::StreamWrapper; +use crate::stream_wrapper::CancellableStreamWrapper; /// The frontend instance contains necessary components, and implements many /// traits, like [`servers::query_handler::grpc::GrpcQueryHandler`], @@ -187,63 +188,71 @@ impl Instance { None, ); - let output = match stmt { - Statement::Query(_) | Statement::Explain(_) | Statement::Delete(_) => { - // TODO: remove this when format is supported in datafusion - if let Statement::Explain(explain) = &stmt { - if let Some(format) = explain.format() { - query_ctx.set_explain_format(format.to_string()); + let query_fut = async { + match stmt { + Statement::Query(_) | Statement::Explain(_) | Statement::Delete(_) => { + // TODO: remove this when format is supported in datafusion + if let Statement::Explain(explain) = &stmt { + if let Some(format) = explain.format() { + query_ctx.set_explain_format(format.to_string()); + } } + + let stmt = QueryStatement::Sql(stmt); + let plan = self + .statement_executor + .plan(&stmt, query_ctx.clone()) + .await?; + + let QueryStatement::Sql(stmt) = stmt else { + unreachable!() + }; + query_interceptor.pre_execute(&stmt, Some(&plan), query_ctx.clone())?; + self.statement_executor + .exec_plan(plan, query_ctx) + .await + .context(TableOperationSnafu) } + Statement::Tql(tql) => { + let plan = self + .statement_executor + .plan_tql(tql.clone(), &query_ctx) + .await?; - let stmt = QueryStatement::Sql(stmt); - let plan = self - .statement_executor - .plan(&stmt, 
query_ctx.clone()) - .await?; - - let QueryStatement::Sql(stmt) = stmt else { - unreachable!() - }; - query_interceptor.pre_execute(&stmt, Some(&plan), query_ctx.clone())?; - - self.statement_executor.exec_plan(plan, query_ctx).await - } - Statement::Tql(tql) => { - let plan = self - .statement_executor - .plan_tql(tql.clone(), &query_ctx) - .await?; - - query_interceptor.pre_execute( - &Statement::Tql(tql), - Some(&plan), - query_ctx.clone(), - )?; - - self.statement_executor.exec_plan(plan, query_ctx).await - } - _ => { - query_interceptor.pre_execute(&stmt, None, query_ctx.clone())?; - - self.statement_executor.execute_sql(stmt, query_ctx).await + query_interceptor.pre_execute( + &Statement::Tql(tql), + Some(&plan), + query_ctx.clone(), + )?; + self.statement_executor + .exec_plan(plan, query_ctx) + .await + .context(TableOperationSnafu) + } + _ => { + query_interceptor.pre_execute(&stmt, None, query_ctx.clone())?; + self.statement_executor + .execute_sql(stmt, query_ctx) + .await + .context(TableOperationSnafu) + } } }; - match output { - Ok(output) => { + CancellableFuture::new(query_fut, ticket.cancellation_handle.clone()) + .await + .map_err(|_| error::CancelledSnafu.build())? + .map(|output| { let Output { meta, data } = output; let data = match data { OutputData::Stream(stream) => { - OutputData::Stream(Box::pin(StreamWrapper::new(stream, ticket))) + OutputData::Stream(Box::pin(CancellableStreamWrapper::new(stream, ticket))) } other => other, }; - Ok(Output { data, meta }) - } - Err(e) => Err(e).context(TableOperationSnafu), - } + Output { data, meta } + }) } } @@ -605,6 +614,8 @@ pub fn check_permission( } // cursor operations are always allowed once it's created Statement::FetchCursor(_) | Statement::CloseCursor(_) => {} + // User can only kill process in their own catalog. 
+ Statement::Kill(_) => {} } Ok(()) }
diff --git a/src/frontend/src/instance/builder.rs b/src/frontend/src/instance/builder.rs index 3a5dfcca3b..e9c132da42 100644 --- a/src/frontend/src/instance/builder.rs +++ b/src/frontend/src/instance/builder.rs @@ -180,6 +180,7 @@ impl FrontendBuilder { local_cache_invalidator, inserter.clone(), table_route_cache, + Some(process_manager.clone()), )); let pipeline_operator = Arc::new(PipelineOperator::new(
diff --git a/src/frontend/src/stream_wrapper.rs b/src/frontend/src/stream_wrapper.rs index c38996c57c..2c1f4519b4 100644 --- a/src/frontend/src/stream_wrapper.rs +++ b/src/frontend/src/stream_wrapper.rs @@ -15,37 +15,52 @@ use std::pin::Pin; use std::task::{Context, Poll}; +use catalog::process_manager::Ticket; use common_recordbatch::adapter::RecordBatchMetrics; use common_recordbatch::{OrderOption, RecordBatch, RecordBatchStream, SendableRecordBatchStream}; use datatypes::schema::SchemaRef; use futures::Stream; -pub struct StreamWrapper<T> { +pub struct CancellableStreamWrapper { inner: SendableRecordBatchStream, - _attachment: T, + ticket: Ticket, } -impl<T> Unpin for StreamWrapper<T> {} +impl Unpin for CancellableStreamWrapper {} -impl<T> StreamWrapper<T> { - pub fn new(stream: SendableRecordBatchStream, attachment: T) -> Self { +impl CancellableStreamWrapper { + pub fn new(stream: SendableRecordBatchStream, ticket: Ticket) -> Self { Self { inner: stream, - _attachment: attachment, + ticket, } } } -impl<T> Stream for StreamWrapper<T> { +impl Stream for CancellableStreamWrapper { type Item = common_recordbatch::error::Result<RecordBatch>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { let this = &mut *self; - Pin::new(&mut this.inner).poll_next(cx) + if this.ticket.cancellation_handle.is_cancelled() { + return Poll::Ready(Some(common_recordbatch::error::StreamCancelledSnafu.fail())); + } + + if let Poll::Ready(res) = Pin::new(&mut this.inner).poll_next(cx) { + return Poll::Ready(res); + } + + // on pending, register cancellation waker. + this.ticket.cancellation_handle.waker().register(cx.waker()); + // check if canceled again.
+ if this.ticket.cancellation_handle.is_cancelled() { + return Poll::Ready(Some(common_recordbatch::error::StreamCancelledSnafu.fail())); + } + Poll::Pending } } -impl RecordBatchStream for StreamWrapper { +impl RecordBatchStream for CancellableStreamWrapper { fn schema(&self) -> SchemaRef { self.inner.schema() } @@ -58,3 +73,295 @@ impl RecordBatchStream for StreamWrapper { self.inner.metrics() } } + +#[cfg(test)] +mod tests { + use std::pin::Pin; + use std::sync::Arc; + use std::task::{Context, Poll}; + use std::time::Duration; + + use catalog::process_manager::ProcessManager; + use common_recordbatch::adapter::RecordBatchMetrics; + use common_recordbatch::{OrderOption, RecordBatch, RecordBatchStream}; + use datatypes::data_type::ConcreteDataType; + use datatypes::prelude::VectorRef; + use datatypes::schema::{ColumnSchema, Schema, SchemaRef}; + use datatypes::vectors::Int32Vector; + use futures::{Stream, StreamExt}; + use tokio::time::{sleep, timeout}; + + use super::CancellableStreamWrapper; + + // Mock stream for testing + struct MockRecordBatchStream { + schema: SchemaRef, + batches: Vec>, + current: usize, + delay: Option, + } + + impl MockRecordBatchStream { + fn new(batches: Vec>) -> Self { + let schema = Arc::new(Schema::new(vec![ColumnSchema::new( + "test_col", + ConcreteDataType::int32_datatype(), + false, + )])); + + Self { + schema, + batches, + current: 0, + delay: None, + } + } + + fn with_delay(mut self, delay: Duration) -> Self { + self.delay = Some(delay); + self + } + } + + impl Stream for MockRecordBatchStream { + type Item = common_recordbatch::error::Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if let Some(delay) = self.delay { + // Simulate async delay + let waker = cx.waker().clone(); + let delay_clone = delay; + tokio::spawn(async move { + sleep(delay_clone).await; + waker.wake(); + }); + self.delay = None; // Only delay once + return Poll::Pending; + } + + if self.current >= self.batches.len() { + return Poll::Ready(None); + } + + let batch = self.batches[self.current].as_ref().unwrap().clone(); + self.current += 1; + Poll::Ready(Some(Ok(batch))) + } + } + + impl RecordBatchStream for MockRecordBatchStream { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn output_ordering(&self) -> Option<&[OrderOption]> { + None + } + + fn metrics(&self) -> Option { + None + } + } + + fn create_test_batch() -> RecordBatch { + let schema = Arc::new(Schema::new(vec![ColumnSchema::new( + "test_col", + ConcreteDataType::int32_datatype(), + false, + )])); + RecordBatch::new( + schema, + vec![Arc::new(Int32Vector::from_values(0..3)) as VectorRef], + ) + .unwrap() + } + + #[tokio::test] + async fn test_stream_completes_normally() { + let batch = create_test_batch(); + let mock_stream = MockRecordBatchStream::new(vec![Ok(batch.clone())]); + let process_manager = Arc::new(ProcessManager::new("".to_string(), None)); + let ticket = process_manager.register_query( + "catalog".to_string(), + vec![], + "query".to_string(), + "client".to_string(), + None, + ); + + let mut cancellable_stream = CancellableStreamWrapper::new(Box::pin(mock_stream), ticket); + + let result = cancellable_stream.next().await; + assert!(result.is_some()); + assert!(result.unwrap().is_ok()); + + let end_result = cancellable_stream.next().await; + assert!(end_result.is_none()); + } + + #[tokio::test] + async fn test_stream_cancelled_before_start() { + let batch = create_test_batch(); + let mock_stream = MockRecordBatchStream::new(vec![Ok(batch)]); + let 
process_manager = Arc::new(ProcessManager::new("".to_string(), None)); + let ticket = process_manager.register_query( + "catalog".to_string(), + vec![], + "query".to_string(), + "client".to_string(), + None, + ); + + // Cancel before creating the wrapper + ticket.cancellation_handle.cancel(); + + let mut cancellable_stream = CancellableStreamWrapper::new(Box::pin(mock_stream), ticket); + + let result = cancellable_stream.next().await; + assert!(result.is_some()); + assert!(result.unwrap().is_err()); + } + + #[tokio::test] + async fn test_stream_cancelled_during_execution() { + let batch = create_test_batch(); + let mock_stream = + MockRecordBatchStream::new(vec![Ok(batch)]).with_delay(Duration::from_millis(100)); + let process_manager = Arc::new(ProcessManager::new("".to_string(), None)); + let ticket = process_manager.register_query( + "catalog".to_string(), + vec![], + "query".to_string(), + "client".to_string(), + None, + ); + let cancellation_handle = ticket.cancellation_handle.clone(); + + let mut cancellable_stream = CancellableStreamWrapper::new(Box::pin(mock_stream), ticket); + + // Cancel after a short delay + tokio::spawn(async move { + sleep(Duration::from_millis(50)).await; + cancellation_handle.cancel(); + }); + + let result = cancellable_stream.next().await; + assert!(result.is_some()); + assert!(result.unwrap().is_err()); + } + + #[tokio::test] + async fn test_stream_completes_before_cancellation() { + let batch = create_test_batch(); + let mock_stream = MockRecordBatchStream::new(vec![Ok(batch.clone())]); + let process_manager = Arc::new(ProcessManager::new("".to_string(), None)); + let ticket = process_manager.register_query( + "catalog".to_string(), + vec![], + "query".to_string(), + "client".to_string(), + None, + ); + let cancellation_handle = ticket.cancellation_handle.clone(); + + let mut cancellable_stream = CancellableStreamWrapper::new(Box::pin(mock_stream), ticket); + + // Try to cancel after the stream should have completed + tokio::spawn(async move { + sleep(Duration::from_millis(100)).await; + cancellation_handle.cancel(); + }); + + let result = cancellable_stream.next().await; + assert!(result.is_some()); + assert!(result.unwrap().is_ok()); + } + + #[tokio::test] + async fn test_multiple_batches() { + let batch1 = create_test_batch(); + let batch2 = create_test_batch(); + let mock_stream = MockRecordBatchStream::new(vec![Ok(batch1), Ok(batch2)]); + let process_manager = Arc::new(ProcessManager::new("".to_string(), None)); + let ticket = process_manager.register_query( + "catalog".to_string(), + vec![], + "query".to_string(), + "client".to_string(), + None, + ); + + let mut cancellable_stream = CancellableStreamWrapper::new(Box::pin(mock_stream), ticket); + + // First batch + let result1 = cancellable_stream.next().await; + assert!(result1.is_some()); + assert!(result1.unwrap().is_ok()); + + // Second batch + let result2 = cancellable_stream.next().await; + assert!(result2.is_some()); + assert!(result2.unwrap().is_ok()); + + // End of stream + let end_result = cancellable_stream.next().await; + assert!(end_result.is_none()); + } + + #[tokio::test] + async fn test_record_batch_stream_methods() { + let batch = create_test_batch(); + let mock_stream = MockRecordBatchStream::new(vec![Ok(batch)]); + let process_manager = Arc::new(ProcessManager::new("".to_string(), None)); + let ticket = process_manager.register_query( + "catalog".to_string(), + vec![], + "query".to_string(), + "client".to_string(), + None, + ); + + let cancellable_stream = 
CancellableStreamWrapper::new(Box::pin(mock_stream), ticket); + + // Test schema method + let schema = cancellable_stream.schema(); + assert_eq!(schema.column_schemas().len(), 1); + assert_eq!(schema.column_schemas()[0].name, "test_col"); + + // Test output_ordering method + assert!(cancellable_stream.output_ordering().is_none()); + + // Test metrics method + assert!(cancellable_stream.metrics().is_none()); + } + + #[tokio::test] + async fn test_cancellation_during_pending_poll() { + let batch = create_test_batch(); + let mock_stream = + MockRecordBatchStream::new(vec![Ok(batch)]).with_delay(Duration::from_millis(200)); + let process_manager = Arc::new(ProcessManager::new("".to_string(), None)); + let ticket = process_manager.register_query( + "catalog".to_string(), + vec![], + "query".to_string(), + "client".to_string(), + None, + ); + let cancellation_handle = ticket.cancellation_handle.clone(); + + let mut cancellable_stream = CancellableStreamWrapper::new(Box::pin(mock_stream), ticket); + + // Cancel while the stream is pending + tokio::spawn(async move { + sleep(Duration::from_millis(50)).await; + cancellation_handle.cancel(); + }); + + let result = timeout(Duration::from_millis(300), cancellable_stream.next()).await; + assert!(result.is_ok()); + let stream_result = result.unwrap(); + assert!(stream_result.is_some()); + assert!(stream_result.unwrap().is_err()); + } +} diff --git a/src/operator/Cargo.toml b/src/operator/Cargo.toml index c0d7f5fd17..f1d2c641e6 100644 --- a/src/operator/Cargo.toml +++ b/src/operator/Cargo.toml @@ -26,6 +26,7 @@ common-base.workspace = true common-catalog.workspace = true common-datasource.workspace = true common-error.workspace = true +common-frontend.workspace = true common-function.workspace = true common-grpc.workspace = true common-grpc-expr.workspace = true diff --git a/src/operator/src/error.rs b/src/operator/src/error.rs index 1242138539..900d12b272 100644 --- a/src/operator/src/error.rs +++ b/src/operator/src/error.rs @@ -837,6 +837,15 @@ pub enum Error { #[snafu(implicit)] location: Location, }, + + #[snafu(display("Invalid process id: {}", id))] + InvalidProcessId { id: String }, + + #[snafu(display("ProcessManager is not present, this can be caused by misconfiguration."))] + ProcessManagerMissing { + #[snafu(implicit)] + location: Location, + }, } pub type Result = std::result::Result; @@ -964,6 +973,8 @@ impl ErrorExt for Error { Error::ColumnOptions { source, .. } => source.status_code(), Error::DecodeFlightData { source, .. } => source.status_code(), Error::ComputeArrow { .. } => StatusCode::Internal, + Error::InvalidProcessId { .. } => StatusCode::InvalidArguments, + Error::ProcessManagerMissing { .. 
} => StatusCode::Unexpected, } }
diff --git a/src/operator/src/statement.rs b/src/operator/src/statement.rs index ad104a5ae9..7968280be0 100644 --- a/src/operator/src/statement.rs +++ b/src/operator/src/statement.rs @@ -21,6 +21,7 @@ mod cursor; mod ddl; mod describe; mod dml; +mod kill; mod set; mod show; mod tql; @@ -32,6 +33,7 @@ use std::time::Duration; use async_stream::stream; use catalog::kvbackend::KvBackendCatalogManager; +use catalog::process_manager::ProcessManagerRef; use catalog::CatalogManagerRef; use client::{OutputData, RecordBatches}; use common_error::ext::BoxedError; @@ -94,11 +96,13 @@ pub struct StatementExecutor { partition_manager: PartitionRuleManagerRef, cache_invalidator: CacheInvalidatorRef, inserter: InserterRef, + process_manager: Option<ProcessManagerRef>, } pub type StatementExecutorRef = Arc<StatementExecutor>; impl StatementExecutor { + #[allow(clippy::too_many_arguments)] pub fn new( catalog_manager: CatalogManagerRef, query_engine: QueryEngineRef, @@ -107,6 +111,7 @@ impl StatementExecutor { cache_invalidator: CacheInvalidatorRef, inserter: InserterRef, table_route_cache: TableRouteCacheRef, + process_manager: Option<ProcessManagerRef>, ) -> Self { Self { catalog_manager, @@ -118,6 +123,7 @@ impl StatementExecutor { partition_manager: Arc::new(PartitionRuleManager::new(kv_backend, table_route_cache)), cache_invalidator, inserter, + process_manager, } } @@ -363,6 +369,7 @@ impl StatementExecutor { Statement::ShowSearchPath(_) => self.show_search_path(query_ctx).await, Statement::Use(db) => self.use_database(db, query_ctx).await, Statement::Admin(admin) => self.execute_admin_command(admin, query_ctx).await, + Statement::Kill(id) => self.execute_kill(query_ctx, id).await, } }
diff --git a/src/operator/src/statement/kill.rs b/src/operator/src/statement/kill.rs new file mode 100644 index 0000000000..7da90c13cb --- /dev/null +++ b/src/operator/src/statement/kill.rs @@ -0,0 +1,46 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +use common_frontend::DisplayProcessId; +use common_query::Output; +use common_telemetry::error; +use session::context::QueryContextRef; +use snafu::ResultExt; + +use crate::error; +use crate::statement::StatementExecutor; + +impl StatementExecutor { + pub async fn execute_kill( + &self, + query_ctx: QueryContextRef, + process_id: String, + ) -> crate::error::Result<Output> { + let Some(process_manager) = self.process_manager.as_ref() else { + error!("Process manager is not initialized"); + return error::ProcessManagerMissingSnafu.fail(); + }; + + let display_id = DisplayProcessId::try_from(process_id.as_str()) + .map_err(|_| error::InvalidProcessIdSnafu { id: process_id }.build())?; + + let current_user_catalog = query_ctx.current_catalog().to_string(); + process_manager + .kill_process(display_id.server_addr, current_user_catalog, display_id.id) + .await + .context(error::CatalogSnafu)?; + + Ok(Output::new_with_affected_rows(0)) + } +}
diff --git a/src/servers/src/grpc/frontend_grpc_handler.rs b/src/servers/src/grpc/frontend_grpc_handler.rs index f12ae77a97..6823eb37a7 100644 --- a/src/servers/src/grpc/frontend_grpc_handler.rs +++ b/src/servers/src/grpc/frontend_grpc_handler.rs @@ -13,7 +13,9 @@ // limitations under the License. use api::v1::frontend::frontend_server::Frontend; -use api::v1::frontend::{ListProcessRequest, ListProcessResponse}; +use api::v1::frontend::{ + KillProcessRequest, KillProcessResponse, ListProcessRequest, ListProcessResponse, +}; use catalog::process_manager::ProcessManagerRef; use common_telemetry::error; use tonic::{Code, Request, Response, Status}; @@ -41,12 +43,27 @@ impl Frontend for FrontendGrpcHandler { } else { Some(list_process_request.catalog.as_str()) }; - match self.process_manager.local_processes(catalog) { - Ok(processes) => Ok(Response::new(ListProcessResponse { processes })), - Err(e) => { - error!(e; "Failed to handle list process request"); - Err(Status::new(Code::Internal, e.to_string())) - } - } + let processes = self.process_manager.local_processes(catalog).map_err(|e| { + error!(e; "Failed to handle list process request"); + Status::new(Code::Internal, e.to_string()) + })?; + Ok(Response::new(ListProcessResponse { processes })) + } + + async fn kill_process( + &self, + request: Request<KillProcessRequest>, + ) -> Result<Response<KillProcessResponse>, Status> { + let req = request.into_inner(); + let success = self + .process_manager + .kill_process(req.server_addr, req.catalog, req.process_id) + .await + .map_err(|e| { + error!(e; "Failed to handle kill process request"); + Status::new(Code::Internal, e.to_string()) + })?; + + Ok(Response::new(KillProcessResponse { success })) + } }
diff --git a/src/sql/src/parser.rs b/src/sql/src/parser.rs index 2265a30bdd..63c6465cc7 100644 --- a/src/sql/src/parser.rs +++ b/src/sql/src/parser.rs @@ -121,78 +121,87 @@ impl ParserContext<'_> { /// Parses parser context to a set of statements.
pub fn parse_statement(&mut self) -> Result { match self.parser.peek_token().token { - Token::Word(w) => { - match w.keyword { - Keyword::CREATE => { - let _ = self.parser.next_token(); - self.parse_create() - } + Token::Word(w) => match w.keyword { + Keyword::CREATE => { + let _ = self.parser.next_token(); + self.parse_create() + } - Keyword::EXPLAIN => { - let _ = self.parser.next_token(); - self.parse_explain() - } + Keyword::EXPLAIN => { + let _ = self.parser.next_token(); + self.parse_explain() + } - Keyword::SHOW => { - let _ = self.parser.next_token(); - self.parse_show() - } + Keyword::SHOW => { + let _ = self.parser.next_token(); + self.parse_show() + } - Keyword::DELETE => self.parse_delete(), + Keyword::DELETE => self.parse_delete(), - Keyword::DESCRIBE | Keyword::DESC => { - let _ = self.parser.next_token(); - self.parse_describe() - } + Keyword::DESCRIBE | Keyword::DESC => { + let _ = self.parser.next_token(); + self.parse_describe() + } - Keyword::INSERT => self.parse_insert(), + Keyword::INSERT => self.parse_insert(), - Keyword::REPLACE => self.parse_replace(), + Keyword::REPLACE => self.parse_replace(), - Keyword::SELECT | Keyword::WITH | Keyword::VALUES => self.parse_query(), + Keyword::SELECT | Keyword::WITH | Keyword::VALUES => self.parse_query(), - Keyword::ALTER => self.parse_alter(), + Keyword::ALTER => self.parse_alter(), - Keyword::DROP => self.parse_drop(), + Keyword::DROP => self.parse_drop(), - Keyword::COPY => self.parse_copy(), + Keyword::COPY => self.parse_copy(), - Keyword::TRUNCATE => self.parse_truncate(), + Keyword::TRUNCATE => self.parse_truncate(), - Keyword::SET => self.parse_set_variables(), + Keyword::SET => self.parse_set_variables(), - Keyword::ADMIN => self.parse_admin_command(), + Keyword::ADMIN => self.parse_admin_command(), - Keyword::NoKeyword - if w.quote_style.is_none() && w.value.to_uppercase() == tql_parser::TQL => - { - self.parse_tql() - } + Keyword::NoKeyword + if w.quote_style.is_none() && w.value.to_uppercase() == tql_parser::TQL => + { + self.parse_tql() + } - Keyword::DECLARE => self.parse_declare_cursor(), + Keyword::DECLARE => self.parse_declare_cursor(), - Keyword::FETCH => self.parse_fetch_cursor(), + Keyword::FETCH => self.parse_fetch_cursor(), - Keyword::CLOSE => self.parse_close_cursor(), + Keyword::CLOSE => self.parse_close_cursor(), - Keyword::USE => { - let _ = self.parser.next_token(); + Keyword::USE => { + let _ = self.parser.next_token(); - let database_name = self.parser.parse_identifier().with_context(|_| { + let database_name = self.parser.parse_identifier().with_context(|_| { + error::UnexpectedSnafu { + expected: "a database name", + actual: self.peek_token_as_string(), + } + })?; + Ok(Statement::Use( + Self::canonicalize_identifier(database_name).value, + )) + } + + Keyword::KILL => { + let _ = self.parser.next_token(); + let process_id_ident = + self.parser.parse_literal_string().with_context(|_| { error::UnexpectedSnafu { - expected: "a database name", + expected: "process id string literal", actual: self.peek_token_as_string(), } })?; - Ok(Statement::Use( - Self::canonicalize_identifier(database_name).value, - )) - } - - // todo(hl) support more statements. 
- _ => self.unsupported(self.peek_token_as_string()), + Ok(Statement::Kill(process_id_ident)) } - } + + _ => self.unsupported(self.peek_token_as_string()), + }, Token::LParen => self.parse_query(), unexpected => self.unsupported(unexpected.to_string()), } diff --git a/src/sql/src/statements/statement.rs b/src/sql/src/statements/statement.rs index 4ab10d52ae..56492e75bc 100644 --- a/src/sql/src/statements/statement.rs +++ b/src/sql/src/statements/statement.rs @@ -138,6 +138,8 @@ pub enum Statement { FetchCursor(FetchCursor), // CLOSE CloseCursor(CloseCursor), + // KILL + Kill(String), } impl Display for Statement { @@ -194,6 +196,7 @@ impl Display for Statement { Statement::DeclareCursor(s) => s.fmt(f), Statement::FetchCursor(s) => s.fmt(f), Statement::CloseCursor(s) => s.fmt(f), + Statement::Kill(k) => k.fmt(f), } } }
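How the pieces compose end to end: `register_query` hands the caller a `Ticket` carrying an `Arc<CancellationHandle>`, the query future is wrapped in `CancellableFuture`, the result stream in `CancellableStreamWrapper`, and `KILL` ultimately calls `CancellationHandle::cancel()` through `ProcessManager::kill_process`. The snippet below is a minimal sketch of that cancellation handshake, assuming a tokio runtime and the `common_base::cancellation` module added in this diff; the standalone `main` wrapper and the sleep durations are illustrative only, mirroring the `test_cancellable_future_cancelled_during_execution` test above.

use std::sync::Arc;
use std::time::Duration;

use common_base::cancellation::{CancellableFuture, CancellationHandle, Cancelled};

#[tokio::main]
async fn main() {
    // One handle per registered query; ProcessManager keeps it in CancellableProcess
    // and the Ticket exposes a clone to the execution path.
    let handle = Arc::new(CancellationHandle::default());

    // Wrap the query future; polling returns Err(Cancelled) once the handle fires.
    let query = CancellableFuture::new(
        async {
            tokio::time::sleep(Duration::from_secs(10)).await;
            42
        },
        handle.clone(),
    );

    // A KILL arriving on another task (ProcessManager::kill_process) just calls cancel(),
    // which flips the flag and wakes the waker registered by the pending future.
    tokio::spawn(async move {
        tokio::time::sleep(Duration::from_millis(50)).await;
        handle.cancel();
    });

    match query.await {
        Ok(value) => println!("query finished: {value}"),
        Err(Cancelled) => println!("query was cancelled"),
    }
}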