From 0f2f20d4b70e41f31a276f951de81dce5fb91fbc Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Tue, 3 Feb 2026 21:32:11 +0800 Subject: [PATCH] feat: reduce unit test suite wall time (#7657) * feat: reduce wait timt of from 57s to 0.6s Signed-off-by: Ruihang Xia * , , , Signed-off-by: Ruihang Xia * test_query_concurrently Signed-off-by: Ruihang Xia --------- Signed-off-by: Ruihang Xia --- src/frontend/src/frontend.rs | 36 ++++- src/servers/tests/mysql/mysql_server_test.rs | 19 ++- src/sql/src/partition.rs | 154 +++++++++++-------- tests-integration/src/standalone.rs | 21 ++- tests-integration/src/test_util.rs | 152 ++++++++++++++++++ tests-integration/tests/http.rs | 46 +++++- tests-integration/tests/sql.rs | 79 +++++++--- 7 files changed, 391 insertions(+), 116 deletions(-) diff --git a/src/frontend/src/frontend.rs b/src/frontend/src/frontend.rs index f42150582a..69ae59517e 100644 --- a/src/frontend/src/frontend.rs +++ b/src/frontend/src/frontend.rs @@ -207,6 +207,10 @@ mod tests { let mut requests = request.into_inner(); let suspend = self.suspend.clone(); async move { + // Make the heartbeat interval short in unit tests to reduce the waiting time. + // Only the handshake response needs to populate it (as metasrv does). + let heartbeat_interval_ms = Duration::from_millis(200).as_millis() as u64; + let mut is_handshake = true; while let Some(request) = requests.next().await { if let Err(e) = request { let _ = tx.send(Err(e)).await; @@ -220,9 +224,16 @@ mod tests { )), ..Default::default() }); + let heartbeat_config = + is_handshake.then_some(api::v1::meta::HeartbeatConfig { + heartbeat_interval_ms, + retry_interval_ms: heartbeat_interval_ms, + }); + is_handshake = false; let response = HeartbeatResponse { header: Some(ResponseHeader::success()), mailbox_message, + heartbeat_config, ..Default::default() }; @@ -376,6 +387,21 @@ mod tests { } } + async fn wait_for_suspend_state(frontend: &Frontend, expected: bool) { + let check = || frontend.instance.is_suspended() == expected; + if check() { + return; + } + + tokio::time::timeout(Duration::from_secs(5), async move { + while !check() { + tokio::time::sleep(Duration::from_millis(20)).await; + } + }) + .await + .unwrap(); + } + #[tokio::test(flavor = "multi_thread", worker_threads = 4)] async fn test_suspend_frontend() -> Result<()> { common_telemetry::init_default_ut_logging(); @@ -412,12 +438,6 @@ mod tests { let meta_client = create_meta_client(&meta_client_options, server.clone()).await; let frontend = create_frontend(&options, meta_client).await?; - use common_meta::distributed_time_constants::{ - BASE_HEARTBEAT_INTERVAL, frontend_heartbeat_interval, - }; - let frontend_heartbeat_interval = - frontend_heartbeat_interval(BASE_HEARTBEAT_INTERVAL) + Duration::from_secs(1); - tokio::time::sleep(frontend_heartbeat_interval).await; // initial state: not suspend: assert!(!frontend.instance.is_suspended()); verify_suspend_state_by_http(&frontend, Ok(r#"[{"records":{"schema":{"column_schemas":[{"name":"Int64(1)","data_type":"Int64"}]},"rows":[[1]],"total_rows":1}}]"#)).await; @@ -434,7 +454,7 @@ mod tests { // make heartbeat server returned "suspend" instruction, server.suspend.store(true, Ordering::Relaxed); - tokio::time::sleep(frontend_heartbeat_interval).await; + wait_for_suspend_state(&frontend, true).await; // ... then the frontend is suspended: assert!(frontend.instance.is_suspended()); verify_suspend_state_by_http( @@ -450,7 +470,7 @@ mod tests { // make heartbeat server NOT returned "suspend" instruction, server.suspend.store(false, Ordering::Relaxed); - tokio::time::sleep(frontend_heartbeat_interval).await; + wait_for_suspend_state(&frontend, false).await; // ... then frontend's suspend state is cleared: assert!(!frontend.instance.is_suspended()); verify_suspend_state_by_http(&frontend, Ok(r#"[{"records":{"schema":{"column_schemas":[{"name":"Int64(1)","data_type":"Int64"}]},"rows":[[1]],"total_rows":1}}]"#)).await; diff --git a/src/servers/tests/mysql/mysql_server_test.rs b/src/servers/tests/mysql/mysql_server_test.rs index 16272af98f..3aa7b98f39 100644 --- a/src/servers/tests/mysql/mysql_server_test.rs +++ b/src/servers/tests/mysql/mysql_server_test.rs @@ -26,7 +26,6 @@ use datatypes::schema::{ColumnSchema, Schema}; use datatypes::value::Value; use mysql_async::prelude::*; use mysql_async::{Conn, Row, SslOpts}; -use rand::Rng; use servers::error::Result; use servers::install_ring_crypto_provider; use servers::mysql::server::{MysqlServer, MysqlSpawnConfig, MysqlSpawnRef}; @@ -436,19 +435,23 @@ async fn test_query_concurrently() -> Result<()> { let server_port = server_addr.port(); let threads = 4; - let expect_executed_queries_per_worker = 1000; + let expect_executed_queries_per_worker = 200; + let queries = Arc::new( + (0..100u32) + .map(|expected| format!("SELECT uint32s FROM numbers WHERE uint32s = {expected}")) + .collect::>(), + ); let mut join_handles = vec![]; - for _ in 0..threads { + for worker_id in 0..threads { + let queries = queries.clone(); join_handles.push(tokio::spawn(async move { let mut connection = create_connection_default_db_name(server_port, false) .await .unwrap(); - for _ in 0..expect_executed_queries_per_worker { - let expected: u32 = rand::rng().random_range(0..100); + for i in 0..expect_executed_queries_per_worker { + let expected: u32 = ((i + worker_id) % 100) as u32; let result: u32 = connection - .query_first(format!( - "SELECT uint32s FROM numbers WHERE uint32s = {expected}" - )) + .query_first(&queries[expected as usize]) .await .unwrap() .unwrap(); diff --git a/src/sql/src/partition.rs b/src/sql/src/partition.rs index 47b74f05c1..0f1a7526dc 100644 --- a/src/sql/src/partition.rs +++ b/src/sql/src/partition.rs @@ -120,8 +120,6 @@ fn partition_rules_for_uuid(partition_num: u32, ident: &str) -> Result #[cfg(test)] mod tests { - use std::collections::HashMap; - use sqlparser::ast::{Expr, ValueWithSpan}; use sqlparser::dialect::GenericDialect; use sqlparser::parser::Parser; @@ -180,20 +178,16 @@ mod tests { } fn check_distribution(test_partition: u32, test_uuid_num: usize) -> bool { - // Generate test_uuid_num random uuids. - let uuids = (0..test_uuid_num) - .map(|_| Uuid::new_v4().to_string().replace("-", "").to_lowercase()) - .collect::>(); - // Generate the partition rules. let rules = partition_rules_for_uuid(test_partition, "test_trace_id").unwrap(); + let upper_bounds = upper_bounds_from_rules(&rules); // Collect the number of partitions for each uuid. - let mut stats = HashMap::new(); - for uuid in uuids { - let partition = allocate_partition_for_uuid(uuid.clone(), &rules); - // Count the number of uuids in each partition. - *stats.entry(partition).or_insert(0) += 1; + let mut counts = vec![0usize; test_partition as usize]; + for _ in 0..test_uuid_num { + let uuid = Uuid::new_v4().simple().to_string(); + let partition = allocate_partition_for_uuid(uuid.as_str(), &upper_bounds); + counts[partition] += 1; } // Check if the partition distribution is uniform. @@ -203,7 +197,7 @@ mod tests { let tolerance = 100.0 / test_partition as f64 * 0.30; // For each partition, its ratio should be as close as possible to the expected ratio. - for (_, count) in stats { + for count in counts { let ratio = (count as f64 / test_uuid_num as f64) * 100.0; if (ratio - expected_ratio).abs() >= tolerance { return false; @@ -213,58 +207,94 @@ mod tests { true } - fn allocate_partition_for_uuid(uuid: String, rules: &[Expr]) -> usize { - for (i, rule) in rules.iter().enumerate() { - if let Expr::BinaryOp { left, op: _, right } = rule { - if i == 0 { - // Hit the leftmost rule. - if let Expr::Value(ValueWithSpan { - value: Value::SingleQuotedString(leftmost), - .. - }) = *right.clone() - && uuid < leftmost - { - return i; - } - } else if i == rules.len() - 1 { - // Hit the rightmost rule. - if let Expr::Value(ValueWithSpan { - value: Value::SingleQuotedString(rightmost), - .. - }) = *right.clone() - && uuid >= rightmost - { - return i; - } - } else { - // Hit the middle rules. - if let Expr::BinaryOp { - left: _, - op: _, - right: inner_right, - } = *left.clone() - && let Expr::Value(ValueWithSpan { - value: Value::SingleQuotedString(lower), - .. - }) = *inner_right.clone() - && let Expr::BinaryOp { - left: _, - op: _, - right: inner_right, - } = *right.clone() - && let Expr::Value(ValueWithSpan { - value: Value::SingleQuotedString(upper), - .. - }) = *inner_right.clone() - && uuid >= lower - && uuid < upper - { - return i; - } + fn upper_bounds_from_rules<'a>(rules: &'a [Expr]) -> Vec<&'a str> { + let mut upper_bounds = Vec::with_capacity(rules.len().saturating_sub(1)); + let mut prev_upper: Option<&'a str> = None; + + for (idx, rule) in rules.iter().enumerate() { + let (lower, upper) = extract_rule_bounds(rule); + match idx { + 0 => { + assert!(lower.is_none()); + assert!(upper.is_some()); } + idx if idx == rules.len() - 1 => { + assert_eq!(lower, prev_upper); + assert!(upper.is_none()); + } + _ => { + assert_eq!(lower, prev_upper); + assert!(upper.is_some()); + } + } + + if idx < rules.len() - 1 { + upper_bounds.push(upper.unwrap()); + } + + prev_upper = upper; + } + + upper_bounds + } + + fn extract_rule_bounds(rule: &Expr) -> (Option<&str>, Option<&str>) { + fn extract_single_quoted_string(expr: &Expr) -> Option<&str> { + match expr { + Expr::Value(ValueWithSpan { + value: Value::SingleQuotedString(value), + .. + }) => Some(value.as_str()), + _ => None, } } - panic!("No partition found for uuid: {}, rules: {:?}", uuid, rules); + match rule { + Expr::BinaryOp { + op: BinaryOperator::Lt, + right, + .. + } => (None, extract_single_quoted_string(right)), + Expr::BinaryOp { + op: BinaryOperator::GtEq, + right, + .. + } => (extract_single_quoted_string(right), None), + Expr::BinaryOp { + op: BinaryOperator::And, + left, + right, + } => { + let lower = match left.as_ref() { + Expr::BinaryOp { + op: BinaryOperator::GtEq, + right, + .. + } => extract_single_quoted_string(right), + _ => None, + }; + let upper = match right.as_ref() { + Expr::BinaryOp { + op: BinaryOperator::Lt, + right, + .. + } => extract_single_quoted_string(right), + _ => None, + }; + (lower, upper) + } + _ => (None, None), + } + } + + fn allocate_partition_for_uuid(uuid: &str, upper_bounds: &[&str]) -> usize { + upper_bounds.partition_point(|upper| uuid >= *upper) + } + + #[test] + fn test_extract_rule_bounds() { + let rules = partition_rules_for_uuid(16, "trace_id").unwrap(); + let upper_bounds = upper_bounds_from_rules(&rules); + assert_eq!(upper_bounds.len(), 15); } } diff --git a/tests-integration/src/standalone.rs b/tests-integration/src/standalone.rs index af7047cdaf..acdd0e60b5 100644 --- a/tests-integration/src/standalone.rs +++ b/tests-integration/src/standalone.rs @@ -79,6 +79,7 @@ pub struct GreptimeDbStandaloneBuilder { store_providers: Option>, default_store: Option, plugin: Option, + slow_query_options: SlowQueryOptions, } impl GreptimeDbStandaloneBuilder { @@ -90,6 +91,12 @@ impl GreptimeDbStandaloneBuilder { default_store: None, datanode_wal_config: DatanodeWalConfig::default(), metasrv_wal_config: MetasrvWalConfig::default(), + // Enable slow query log with 1s threshold by default for integration tests. + slow_query_options: SlowQueryOptions { + enable: true, + threshold: Duration::from_secs(1), + ..Default::default() + }, } } @@ -118,6 +125,12 @@ impl GreptimeDbStandaloneBuilder { } } + #[must_use] + pub fn with_slow_query_threshold(mut self, threshold: Duration) -> Self { + self.slow_query_options.threshold = threshold; + self + } + #[must_use] pub fn with_datanode_wal_config(mut self, datanode_wal_config: DatanodeWalConfig) -> Self { self.datanode_wal_config = datanode_wal_config; @@ -332,13 +345,7 @@ impl GreptimeDbStandaloneBuilder { metadata_store: kv_backend_config, wal: self.metasrv_wal_config.clone().into(), grpc: GrpcOptions::default().with_server_addr("127.0.0.1:4001"), - // Enable slow query log with 1s threshold to run the slow query test. - slow_query: SlowQueryOptions { - enable: true, - // Set the threshold to 1s to run the slow query test. - threshold: Duration::from_secs(1), - ..Default::default() - }, + slow_query: self.slow_query_options.clone(), ..StandaloneOptions::default() }; diff --git a/tests-integration/src/test_util.rs b/tests-integration/src/test_util.rs index 9d7ee592a4..0c232c3c69 100644 --- a/tests-integration/src/test_util.rs +++ b/tests-integration/src/test_util.rs @@ -395,6 +395,18 @@ async fn setup_standalone_instance( .await } +async fn setup_standalone_instance_with_slow_query_threshold( + test_name: &str, + store_type: StorageType, + slow_query_threshold: std::time::Duration, +) -> GreptimeDbStandalone { + GreptimeDbStandaloneBuilder::new(test_name) + .with_default_store_type(store_type) + .with_slow_query_threshold(slow_query_threshold) + .build() + .await +} + async fn setup_standalone_instance_with_plugins( test_name: &str, store_type: StorageType, @@ -407,6 +419,20 @@ async fn setup_standalone_instance_with_plugins( .await } +async fn setup_standalone_instance_with_plugins_and_slow_query_threshold( + test_name: &str, + store_type: StorageType, + plugins: Plugins, + slow_query_threshold: std::time::Duration, +) -> GreptimeDbStandalone { + GreptimeDbStandaloneBuilder::new(test_name) + .with_default_store_type(store_type) + .with_plugin(plugins) + .with_slow_query_threshold(slow_query_threshold) + .build() + .await +} + pub async fn setup_test_http_app(store_type: StorageType, name: &str) -> (Router, TestGuard) { let instance = setup_standalone_instance(name, store_type).await; @@ -433,6 +459,36 @@ pub async fn setup_test_http_app_with_frontend( setup_test_http_app_with_frontend_and_user_provider(store_type, name, None).await } +pub async fn setup_test_http_app_with_frontend_and_slow_query_threshold( + store_type: StorageType, + name: &str, + slow_query_threshold: std::time::Duration, +) -> (Router, TestGuard) { + let instance = + setup_standalone_instance_with_slow_query_threshold(name, store_type, slow_query_threshold) + .await; + + create_test_table(instance.fe_instance(), "demo").await; + + let http_opts = HttpOptions { + addr: format!("127.0.0.1:{}", ports::get_port()), + ..Default::default() + }; + + let http_server = HttpServerBuilder::new(http_opts) + .with_sql_handler(instance.fe_instance().clone()) + .with_log_ingest_handler(instance.fe_instance().clone(), None, None) + .with_logs_handler(instance.fe_instance().clone()) + .with_influxdb_handler(instance.fe_instance().clone()) + .with_otlp_handler(instance.fe_instance().clone(), true) + .with_jaeger_handler(instance.fe_instance().clone()) + .with_greptime_config_options(instance.opts.to_toml().unwrap()) + .build(); + + let app = http_server.build(http_server.make_app()).unwrap(); + (app, instance.guard) +} + pub async fn setup_test_http_app_with_frontend_and_user_provider( store_type: StorageType, name: &str, @@ -644,6 +700,57 @@ pub async fn setup_mysql_server( setup_mysql_server_with_user_provider(store_type, name, None).await } +pub async fn setup_mysql_server_with_slow_query_threshold( + store_type: StorageType, + name: &str, + slow_query_threshold: std::time::Duration, +) -> (TestGuard, Arc>) { + let plugins = Plugins::new(); + let instance = setup_standalone_instance_with_plugins_and_slow_query_threshold( + name, + store_type, + plugins, + slow_query_threshold, + ) + .await; + + let runtime = RuntimeBuilder::default() + .worker_threads(2) + .thread_name("mysql-runtime") + .build() + .unwrap(); + + let fe_mysql_addr = format!("127.0.0.1:{}", ports::get_port()); + + let fe_instance_ref = instance.fe_instance().clone(); + let opts = MysqlOptions { + addr: fe_mysql_addr.clone(), + ..Default::default() + }; + let mut mysql_server = MysqlServer::create_server( + runtime, + Arc::new(MysqlSpawnRef::new(fe_instance_ref, None)), + Arc::new(MysqlSpawnConfig::new( + false, + Arc::new( + ReloadableTlsServerConfig::try_new(opts.tls.clone()) + .expect("Failed to load certificates and keys"), + ), + 0, + opts.reject_no_database.unwrap_or(false), + opts.prepared_stmt_cache_size, + )), + None, + ); + + mysql_server + .start(fe_mysql_addr.parse::().unwrap()) + .await + .unwrap(); + + (instance.guard, Arc::new(mysql_server)) +} + pub async fn setup_mysql_server_with_user_provider( store_type: StorageType, name: &str, @@ -701,6 +808,51 @@ pub async fn setup_pg_server( setup_pg_server_with_user_provider(store_type, name, None).await } +pub async fn setup_pg_server_with_slow_query_threshold( + store_type: StorageType, + name: &str, + slow_query_threshold: std::time::Duration, +) -> (TestGuard, Arc>) { + let instance = + setup_standalone_instance_with_slow_query_threshold(name, store_type, slow_query_threshold) + .await; + + let runtime = RuntimeBuilder::default() + .worker_threads(2) + .thread_name("pg-runtime") + .build() + .unwrap(); + + let fe_pg_addr = format!("127.0.0.1:{}", ports::get_port()); + + let fe_instance_ref = instance.fe_instance().clone(); + let opts = PostgresOptions { + addr: fe_pg_addr.clone(), + ..Default::default() + }; + let tls_server_config = Arc::new( + ReloadableTlsServerConfig::try_new(opts.tls.clone()) + .expect("Failed to load certificates and keys"), + ); + + let mut pg_server = Box::new(PostgresServer::new( + fe_instance_ref, + opts.tls.should_force_tls(), + tls_server_config, + 0, + runtime, + None, + None, + )); + + pg_server + .start(fe_pg_addr.parse::().unwrap()) + .await + .unwrap(); + + (instance.guard, Arc::new(pg_server)) +} + pub async fn setup_pg_server_with_user_provider( store_type: StorageType, name: &str, diff --git a/tests-integration/tests/http.rs b/tests-integration/tests/http.rs index ca0a7fda34..138d5d9ab9 100644 --- a/tests-integration/tests/http.rs +++ b/tests-integration/tests/http.rs @@ -60,6 +60,7 @@ use servers::request_memory_limiter::ServerMemoryLimiter; use table::table_name::TableName; use tests_integration::test_util::{ StorageType, setup_test_http_app, setup_test_http_app_with_frontend, + setup_test_http_app_with_frontend_and_slow_query_threshold, setup_test_http_app_with_frontend_and_user_provider, setup_test_prom_app_with_frontend, }; use urlencoding::encode; @@ -665,24 +666,29 @@ async fn test_sql_format_api() { } pub async fn test_http_sql_slow_query(store_type: StorageType) { - let (app, mut guard) = setup_test_http_app_with_frontend(store_type, "sql_api").await; + let (app, mut guard) = setup_test_http_app_with_frontend_and_slow_query_threshold( + store_type, + "sql_api", + Duration::from_millis(100), + ) + .await; let client = TestClient::new(app).await; - let slow_query = "SELECT count(*) FROM generate_series(1, 1000000000)"; + let slow_query = "SELECT count(*) FROM generate_series(1, 50000000)"; let encoded_slow_query = encode(slow_query); let query_params = format!("/v1/sql?sql={encoded_slow_query}"); let res = client.get(&query_params).send().await; assert_eq!(res.status(), StatusCode::OK); - // Wait for the slow query to be recorded. - tokio::time::sleep(Duration::from_secs(5)).await; - let table = format!("{}.{}", DEFAULT_PRIVATE_SCHEMA_NAME, SLOW_QUERY_TABLE_NAME); - let query = format!("SELECT {} FROM {table}", SLOW_QUERY_TABLE_QUERY_COLUMN_NAME); + let query = format!( + "SELECT {} FROM {table} WHERE {} = '{slow_query}'", + SLOW_QUERY_TABLE_QUERY_COLUMN_NAME, SLOW_QUERY_TABLE_QUERY_COLUMN_NAME + ); let expected = format!(r#"[["{}"]]"#, slow_query); - validate_data("test_http_sql_slow_query", &client, &query, &expected).await; + wait_for_data(&client, &query, &expected).await; guard.remove_all().await; } @@ -7297,6 +7303,32 @@ async fn validate_data(test_name: &str, client: &TestClient, sql: &str, expected ); } +async fn wait_for_data(client: &TestClient, sql: &str, expected: &str) { + tokio::time::timeout(Duration::from_secs(10), async { + let encoded_sql = encode(sql); + loop { + let res = client + .get(format!("/v1/sql?sql={encoded_sql}").as_str()) + .send() + .await; + if res.status() != StatusCode::OK { + tokio::time::sleep(Duration::from_millis(50)).await; + continue; + } + let resp = res.text().await; + let v = get_rows_from_output(&resp); + + if expected == v { + break; + } + + tokio::time::sleep(Duration::from_millis(50)).await; + } + }) + .await + .unwrap(); +} + async fn send_req( client: &TestClient, headers: Vec<(HeaderName, HeaderValue)>, diff --git a/tests-integration/tests/sql.rs b/tests-integration/tests/sql.rs index 41bbf2ce4d..28f8535ec2 100644 --- a/tests-integration/tests/sql.rs +++ b/tests-integration/tests/sql.rs @@ -13,6 +13,7 @@ // limitations under the License. use std::collections::HashMap; +use std::time::Duration; use auth::user_provider_from_option; use chrono::{DateTime, NaiveDate, NaiveDateTime, SecondsFormat, Utc}; @@ -27,8 +28,9 @@ use sqlx::postgres::{PgDatabaseError, PgPoolOptions}; use sqlx::types::Decimal; use sqlx::{Connection, Executor, Row}; use tests_integration::test_util::{ - StorageType, setup_mysql_server, setup_mysql_server_with_user_provider, setup_pg_server, - setup_pg_server_with_user_provider, + StorageType, setup_mysql_server, setup_mysql_server_with_slow_query_threshold, + setup_mysql_server_with_user_provider, setup_pg_server, + setup_pg_server_with_slow_query_threshold, setup_pg_server_with_user_provider, }; use tokio_postgres::{Client, NoTls, SimpleQueryMessage}; @@ -693,8 +695,12 @@ pub async fn test_postgres_crud(store_type: StorageType) { pub async fn test_mysql_slow_query(store_type: StorageType) { common_telemetry::init_default_ut_logging(); - let (mut guard, fe_mysql_server) = - setup_mysql_server(store_type, "test_mysql_slow_query").await; + let (mut guard, fe_mysql_server) = setup_mysql_server_with_slow_query_threshold( + store_type, + "test_mysql_slow_query", + Duration::from_millis(100), + ) + .await; let addr = fe_mysql_server.bind_addr().unwrap().to_string(); let pool = MySqlPoolOptions::new() @@ -703,29 +709,38 @@ pub async fn test_mysql_slow_query(store_type: StorageType) { .await .unwrap(); - // The slow query will run at least longer than 1s. - let slow_query = "SELECT count(*) FROM generate_series(1, 1000000000)"; + // The slow query should run longer than the configured threshold. + let slow_query = "SELECT count(*) FROM generate_series(1, 50000000)"; // Simulate a slow query. sqlx::query(slow_query).fetch_all(&pool).await.unwrap(); - // Wait for the slow query to be recorded. - tokio::time::sleep(std::time::Duration::from_secs(5)).await; - let table = format!("{}.{}", DEFAULT_PRIVATE_SCHEMA_NAME, SLOW_QUERY_TABLE_NAME); let query = format!( - "SELECT {}, {}, {}, {} FROM {table}", + "SELECT {}, {}, {}, {} FROM {table} WHERE {} = ?", SLOW_QUERY_TABLE_COST_COLUMN_NAME, SLOW_QUERY_TABLE_THRESHOLD_COLUMN_NAME, SLOW_QUERY_TABLE_QUERY_COLUMN_NAME, - SLOW_QUERY_TABLE_IS_PROMQL_COLUMN_NAME + SLOW_QUERY_TABLE_IS_PROMQL_COLUMN_NAME, + SLOW_QUERY_TABLE_QUERY_COLUMN_NAME, ); - let rows = sqlx::query(&query).fetch_all(&pool).await.unwrap(); - assert_eq!(rows.len(), 1); + let row = tokio::time::timeout(Duration::from_secs(10), async { + loop { + if let Ok(Some(row)) = sqlx::query(&query) + .bind(slow_query) + .fetch_optional(&pool) + .await + { + break row; + } + tokio::time::sleep(Duration::from_millis(50)).await; + } + }) + .await + .unwrap(); // Check the results. - let row = &rows[0]; let cost: u64 = row.get(0); let threshold: u64 = row.get(1); let query: String = row.get(2); @@ -810,7 +825,12 @@ pub async fn test_postgres_bytea(store_type: StorageType) { } pub async fn test_postgres_slow_query(store_type: StorageType) { - let (mut guard, fe_pg_server) = setup_pg_server(store_type, "test_postgres_slow_query").await; + let (mut guard, fe_pg_server) = setup_pg_server_with_slow_query_threshold( + store_type, + "test_postgres_slow_query", + Duration::from_millis(100), + ) + .await; let addr = fe_pg_server.bind_addr().unwrap().to_string(); let pool = PgPoolOptions::new() @@ -819,23 +839,34 @@ pub async fn test_postgres_slow_query(store_type: StorageType) { .await .unwrap(); - let slow_query = "SELECT count(*) FROM generate_series(1, 1000000000)"; + let slow_query = "SELECT count(*) FROM generate_series(1, 50000000)"; let _ = sqlx::query(slow_query).fetch_all(&pool).await.unwrap(); - // Wait for the slow query to be recorded. - tokio::time::sleep(std::time::Duration::from_secs(5)).await; - let table = format!("{}.{}", DEFAULT_PRIVATE_SCHEMA_NAME, SLOW_QUERY_TABLE_NAME); let query = format!( - "SELECT {}, {}, {}, {} FROM {table}", + "SELECT {}, {}, {}, {} FROM {table} WHERE {} = $1", SLOW_QUERY_TABLE_COST_COLUMN_NAME, SLOW_QUERY_TABLE_THRESHOLD_COLUMN_NAME, SLOW_QUERY_TABLE_QUERY_COLUMN_NAME, - SLOW_QUERY_TABLE_IS_PROMQL_COLUMN_NAME + SLOW_QUERY_TABLE_IS_PROMQL_COLUMN_NAME, + SLOW_QUERY_TABLE_QUERY_COLUMN_NAME, ); - let rows = sqlx::query(&query).fetch_all(&pool).await.unwrap(); - assert_eq!(rows.len(), 1); - let row = &rows[0]; + + let row = tokio::time::timeout(Duration::from_secs(10), async { + loop { + if let Ok(Some(row)) = sqlx::query(&query) + .bind(slow_query) + .fetch_optional(&pool) + .await + { + break row; + } + tokio::time::sleep(Duration::from_millis(50)).await; + } + }) + .await + .unwrap(); + let cost: Decimal = row.get(0); let threshold: Decimal = row.get(1); let query: String = row.get(2);