diff --git a/src/frontend/src/instance.rs b/src/frontend/src/instance.rs index 24075601f6..8cc6195e5f 100644 --- a/src/frontend/src/instance.rs +++ b/src/frontend/src/instance.rs @@ -1204,14 +1204,19 @@ fn should_track_plan_process(stmt: Option<&Statement>, plan: &LogicalPlan) -> bo #[cfg(test)] mod tests { use std::collections::HashMap; + use std::future::Future; + use std::pin::Pin; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, Barrier}; + use std::task::{Context, Poll}; use std::thread; use std::time::{Duration, Instant}; use api::v1::meta::{ProcedureDetailResponse, ReconcileRequest, ReconcileResponse}; use catalog::process_manager::ProcessManager; use common_base::Plugins; + use common_error::ext::{BoxedError, PlainError}; + use common_error::status_code::StatusCode; use common_meta::cache::LayeredCacheRegistryBuilder; use common_meta::kv_backend::memory::MemoryKvBackend; use common_meta::procedure_executor::{ExecutorContext, ProcedureExecutor}; @@ -1220,23 +1225,142 @@ mod tests { MigrateRegionRequest, MigrateRegionResponse, ProcedureStateResponse, }; use common_query::Output; + use common_recordbatch::{ + OrderOption, RecordBatch, RecordBatchStream, SendableRecordBatchStream, + }; use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_expr::dml::InsertOp; use datafusion_expr::{LogicalPlanBuilder, LogicalTableSource}; use datatypes::prelude::ConcreteDataType; - use datatypes::schema::{ColumnSchema, Schema as GtSchema}; + use datatypes::schema::{ColumnSchema, Schema as GtSchema, SchemaRef as GtSchemaRef}; use query::query_engine::options::QueryOptions; use session::context::{Channel, ConnInfo, QueryContext, QueryContextBuilder}; + use snafu::{Location, Snafu}; use sql::dialect::GreptimeDbDialect; + use store_api::data_source::DataSource; + use store_api::storage::ScanRequest; use strfmt::Format; - use table::metadata::{TableInfoBuilder, TableMetaBuilder}; + use table::metadata::{FilterPushDownType, TableInfo, TableInfoBuilder, TableMetaBuilder}; use table::test_util::EmptyTable; + use table::{Table, TableRef}; use tokio::sync::{mpsc, oneshot}; use super::*; use crate::frontend::FrontendOptions; use crate::instance::builder::FrontendBuilder; + #[derive(Debug, Snafu)] + enum TestError { + #[snafu(display("Failed to build test cache registry"))] + BuildCacheRegistry { + source: cache::error::Error, + #[snafu(implicit)] + location: Location, + }, + + #[snafu(display("Failed to build test table meta for table: {table_name}"))] + BuildTableMeta { + table_name: String, + source: table::metadata::TableMetaBuilderError, + #[snafu(implicit)] + location: Location, + }, + + #[snafu(display("Failed to build test table info for table: {table_name}"))] + BuildTableInfo { + table_name: String, + source: table::metadata::TableInfoBuilderError, + #[snafu(implicit)] + location: Location, + }, + + #[snafu(display("Failed to register test table: {table_name}"))] + RegisterTable { + table_name: String, + source: catalog::error::Error, + #[snafu(implicit)] + location: Location, + }, + + #[snafu(display("Failed to build test frontend instance"))] + BuildFrontend { + source: crate::error::Error, + #[snafu(implicit)] + location: Location, + }, + + #[snafu(display("Expected exactly one output for SQL `{sql}`, got {actual}"))] + UnexpectedOutputCount { + sql: String, + actual: usize, + #[snafu(implicit)] + location: Location, + }, + + #[snafu(display("Failed to execute SQL `{sql}`"))] + ExecuteSql { + sql: String, + source: crate::error::Error, + #[snafu(implicit)] + location: Location, + }, + + #[snafu(display("Timed out waiting for insert-select start notification"))] + InsertStartTimeout { + source: tokio::time::error::Elapsed, + #[snafu(implicit)] + location: Location, + }, + + #[snafu(display("Insert-select start notification channel closed"))] + InsertStartChannelClosed { + #[snafu(implicit)] + location: Location, + }, + + #[snafu(display("Failed to release blocking insert-select interceptor"))] + ReleaseBlockedInsert { + #[snafu(implicit)] + location: Location, + }, + + #[snafu(display("Timed out waiting for insert-select source to be polled"))] + SourcePollTimeout { + source: tokio::time::error::Elapsed, + #[snafu(implicit)] + location: Location, + }, + + #[snafu(display("Insert-select source poll notification channel closed"))] + SourcePollChannelClosed { + source: oneshot::error::RecvError, + #[snafu(implicit)] + location: Location, + }, + + #[snafu(display("Timed out waiting for insert task to finish"))] + InsertTaskTimeout { + source: tokio::time::error::Elapsed, + #[snafu(implicit)] + location: Location, + }, + + #[snafu(display("Insert task panicked"))] + InsertTaskPanic { + source: tokio::task::JoinError, + #[snafu(implicit)] + location: Location, + }, + + #[snafu(display("Expected insert-select to be cancelled"))] + InsertSelectNotCancelled { + #[snafu(implicit)] + location: Location, + }, + } + + type TestResult = std::result::Result; + fn parse_one_sql(sql: &str) -> Statement { parse_stmt(sql, &GreptimeDbDialect {}).unwrap().remove(0) } @@ -1292,6 +1416,70 @@ mod tests { } } + struct PendingRecordBatchStream { + schema: GtSchemaRef, + polled_tx: Option>, + _finish_tx: oneshot::Sender<()>, + finish_rx: Pin>>, + } + + impl RecordBatchStream for PendingRecordBatchStream { + fn schema(&self) -> GtSchemaRef { + self.schema.clone() + } + + fn output_ordering(&self) -> Option<&[OrderOption]> { + None + } + + fn metrics(&self) -> Option { + None + } + } + + impl Stream for PendingRecordBatchStream { + type Item = common_recordbatch::error::Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if let Some(polled_tx) = self.polled_tx.take() { + let _ = polled_tx.send(()); + } + + match self.finish_rx.as_mut().poll(cx) { + Poll::Ready(_) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } + } + + impl Unpin for PendingRecordBatchStream {} + + struct PendingDataSource { + schema: GtSchemaRef, + polled_tx: std::sync::Mutex>>, + } + + impl DataSource for PendingDataSource { + fn get_stream( + &self, + _request: ScanRequest, + ) -> std::result::Result { + let (finish_tx, finish_rx) = oneshot::channel(); + let mut polled_tx = self.polled_tx.lock().map_err(|_| { + BoxedError::new(PlainError::new( + "pending data source lock poisoned".to_string(), + StatusCode::Unexpected, + )) + })?; + Ok(Box::pin(PendingRecordBatchStream { + schema: self.schema.clone(), + polled_tx: polled_tx.take(), + _finish_tx: finish_tx, + finish_rx: Box::pin(finish_rx), + })) + } + } + struct NoopProcedureExecutor; #[async_trait::async_trait] @@ -1353,18 +1541,18 @@ mod tests { fn test_cache_registry( kv_backend: common_meta::kv_backend::KvBackendRef, - ) -> common_meta::cache::LayeredCacheRegistryRef { - Arc::new( + ) -> TestResult { + Ok(Arc::new( cache::with_default_composite_cache_registry( LayeredCacheRegistryBuilder::default() .add_cache_registry(cache::build_fundamental_cache_registry(kv_backend)), ) - .unwrap() + .context(BuildCacheRegistrySnafu)? .build(), - ) + )) } - fn test_table(table_id: u32, table_name: &str) -> table::TableRef { + fn test_table_info(table_id: u32, table_name: &str) -> TestResult { let schema = Arc::new(GtSchema::new(vec![ ColumnSchema::new("id", ConcreteDataType::int32_datatype(), false), ColumnSchema::new( @@ -1380,36 +1568,85 @@ mod tests { .value_indices(vec![1]) .next_column_id(1024) .build() - .unwrap(); - let table_info = TableInfoBuilder::new(table_name, table_meta) + .with_context(|_| BuildTableMetaSnafu { + table_name: table_name.to_string(), + })?; + + TableInfoBuilder::new(table_name, table_meta) .table_id(table_id) .build() - .unwrap(); + .with_context(|_| BuildTableInfoSnafu { + table_name: table_name.to_string(), + }) + } - EmptyTable::from_table_info(&table_info) + fn test_table(table_id: u32, table_name: &str) -> TestResult { + let table_info = test_table_info(table_id, table_name)?; + Ok(EmptyTable::from_table_info(&table_info)) + } + + fn pending_table( + table_id: u32, + table_name: &str, + polled_tx: oneshot::Sender<()>, + ) -> TestResult { + let table_info = test_table_info(table_id, table_name)?; + let data_source = Arc::new(PendingDataSource { + schema: table_info.meta.schema.clone(), + polled_tx: std::sync::Mutex::new(Some(polled_tx)), + }); + + Ok(Arc::new(Table::new( + Arc::new(table_info), + FilterPushDownType::Unsupported, + data_source, + ))) + } + + async fn test_instance_with_tables( + source_table: TableRef, + target_table: TableRef, + ) -> TestResult { + test_instance_with_plugins(source_table, target_table, Plugins::new()).await } async fn test_instance_with_insert_select_interceptor( interceptor: SqlQueryInterceptorRef, - ) -> Instance { + ) -> TestResult { + let plugins = Plugins::new(); + plugins.insert::>(interceptor); + + test_instance_with_plugins( + test_table(1024, "source")?, + test_table(1025, "target")?, + plugins, + ) + .await + } + + async fn test_instance_with_plugins( + source_table: TableRef, + target_table: TableRef, + plugins: Plugins, + ) -> TestResult { let kv_backend = Arc::new(MemoryKvBackend::new()); let process_manager = Arc::new(ProcessManager::new("test-frontend".to_string(), None)); - let catalog_manager = - catalog::memory::MemoryCatalogManager::new_with_table(test_table(1024, "source")); + let catalog_manager = catalog::memory::MemoryCatalogManager::new_with_table(source_table); + let target_table_name = "target"; catalog_manager .register_table_sync(catalog::RegisterTableRequest { catalog: "greptime".to_string(), schema: "public".to_string(), - table_name: "target".to_string(), + table_name: target_table_name.to_string(), table_id: 1025, - table: test_table(1025, "target"), + table: target_table, }) - .unwrap(); + .with_context(|_| RegisterTableSnafu { + table_name: target_table_name.to_string(), + })?; catalog_manager.register_process_list_table(process_manager.clone()); - let cache_registry = test_cache_registry(kv_backend.clone()); - let plugins = Plugins::new(); - plugins.insert::>(interceptor); + let cache_registry = test_cache_registry(kv_backend.clone())?; FrontendBuilder::new( FrontendOptions::default(), @@ -1423,17 +1660,25 @@ mod tests { .with_plugin(plugins) .try_build() .await - .unwrap() + .context(BuildFrontendSnafu) } async fn execute_one_sql( instance: &Instance, sql: &str, query_ctx: QueryContextRef, - ) -> Result { + ) -> TestResult { let mut results = instance.do_query_inner(sql, query_ctx).await; - assert_eq!(1, results.len()); - results.remove(0) + ensure!( + results.len() == 1, + UnexpectedOutputCountSnafu { + sql: sql.to_string(), + actual: results.len(), + } + ); + results.remove(0).with_context(|_| ExecuteSqlSnafu { + sql: sql.to_string(), + }) } #[test] @@ -1588,12 +1833,12 @@ mod tests { } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn test_insert_select_is_visible_in_show_processlist() { + async fn test_insert_select_is_visible_in_show_processlist() -> TestResult<()> { let insert_sql = "INSERT INTO target SELECT * FROM source"; let (started_tx, mut started_rx) = mpsc::unbounded_channel(); let (finish_tx, finish_rx) = oneshot::channel(); let interceptor = Arc::new(BlockingInsertSelectInterceptor::new(started_tx, finish_rx)); - let instance = Arc::new(test_instance_with_insert_select_interceptor(interceptor).await); + let instance = Arc::new(test_instance_with_insert_select_interceptor(interceptor).await?); let insert_task = tokio::spawn({ let instance = instance.clone(); @@ -1602,20 +1847,77 @@ mod tests { tokio::time::timeout(Duration::from_secs(5), started_rx.recv()) .await - .unwrap() - .unwrap(); + .context(InsertStartTimeoutSnafu)? + .context(InsertStartChannelClosedSnafu)?; - let output = execute_one_sql(&instance, "SHOW PROCESSLIST", test_query_ctx(43)) - .await - .unwrap(); + let output = execute_one_sql(&instance, "SHOW PROCESSLIST", test_query_ctx(43)).await?; let process_list = output.data.pretty_print().await; assert!( process_list.contains(insert_sql), "process list did not contain running insert:\n{process_list}" ); - finish_tx.send(()).unwrap(); - insert_task.await.unwrap().unwrap(); + finish_tx + .send(()) + .map_err(|_| ReleaseBlockedInsertSnafu.build())?; + insert_task.await.context(InsertTaskPanicSnafu)??; + + Ok(()) + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_kill_query_cancels_insert_select() -> TestResult<()> { + assert_kill_cancels_insert_select("KILL QUERY 4242").await + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_kill_process_id_cancels_insert_select() -> TestResult<()> { + assert_kill_cancels_insert_select("KILL 'test-frontend/4242'").await + } + + async fn assert_kill_cancels_insert_select(kill_sql: &str) -> TestResult<()> { + let insert_sql = "INSERT INTO target SELECT * FROM source"; + let (source_polled_tx, source_polled_rx) = oneshot::channel(); + let instance = Arc::new( + test_instance_with_tables( + pending_table(1024, "source", source_polled_tx)?, + test_table(1025, "target")?, + ) + .await?, + ); + + let insert_task = tokio::spawn({ + let instance = instance.clone(); + async move { execute_one_sql(&instance, insert_sql, test_query_ctx(4242)).await } + }); + + tokio::time::timeout(Duration::from_secs(5), source_polled_rx) + .await + .context(SourcePollTimeoutSnafu)? + .context(SourcePollChannelClosedSnafu)?; + + let output = execute_one_sql(&instance, kill_sql, test_query_ctx(43)).await?; + assert!(matches!(output.data, OutputData::AffectedRows(1))); + + let insert_result = tokio::time::timeout(Duration::from_secs(5), insert_task) + .await + .context(InsertTaskTimeoutSnafu)? + .context(InsertTaskPanicSnafu)?; + let err = match insert_result { + Ok(_) => return InsertSelectNotCancelledSnafu.fail(), + Err(TestError::ExecuteSql { source, .. }) => source, + Err(err) => return Err(err), + }; + assert_eq!(StatusCode::Cancelled, err.status_code()); + + let output = execute_one_sql(&instance, "SHOW PROCESSLIST", test_query_ctx(43)).await?; + let process_list = output.data.pretty_print().await; + assert!( + !process_list.contains(insert_sql), + "process list still contains killed insert:\n{process_list}" + ); + + Ok(()) } fn insert_dml_plan() -> LogicalPlan {