From 40e9ce90a71f7b545039cf1feda90e6729b1749e Mon Sep 17 00:00:00 2001 From: Weny Xu Date: Sat, 11 Oct 2025 14:34:17 +0800 Subject: [PATCH] refactor: restructure sqlness to support multiple envs and extract common utils (#7066) * refactor: restructure sqlness to support multiple envs and extract common utils Signed-off-by: WenyXu * chore(ci): update sqlness cmd Signed-off-by: WenyXu * chore: add comments Signed-off-by: WenyXu * fix: error fmt Signed-off-by: WenyXu * fix: only reconnect mysql and pg client Signed-off-by: WenyXu * chore: apply suggestions Signed-off-by: WenyXu --------- Signed-off-by: WenyXu --- .github/workflows/develop.yml | 2 +- Cargo.lock | 6 +- Makefile | 2 +- tests/README.md | 4 +- tests/compat/util.sh | 8 +- tests/runner/src/client.rs | 235 +++++++ tests/runner/src/cmd.rs | 54 ++ tests/runner/src/cmd/bare.rs | 206 +++++++ tests/runner/src/cmd/kube.rs | 95 +++ tests/runner/src/env.rs | 1023 +------------------------------ tests/runner/src/env/bare.rs | 751 +++++++++++++++++++++++ tests/runner/src/env/kube.rs | 204 ++++++ tests/runner/src/formatter.rs | 216 +++++++ tests/runner/src/main.rs | 210 +------ tests/runner/src/server_mode.rs | 5 +- tests/runner/src/util.rs | 25 + 16 files changed, 1812 insertions(+), 1234 deletions(-) create mode 100644 tests/runner/src/client.rs create mode 100644 tests/runner/src/cmd.rs create mode 100644 tests/runner/src/cmd/bare.rs create mode 100644 tests/runner/src/cmd/kube.rs create mode 100644 tests/runner/src/env/bare.rs create mode 100644 tests/runner/src/env/kube.rs create mode 100644 tests/runner/src/formatter.rs diff --git a/.github/workflows/develop.yml b/.github/workflows/develop.yml index 37edc39242..8dde424c8e 100644 --- a/.github/workflows/develop.yml +++ b/.github/workflows/develop.yml @@ -632,7 +632,7 @@ jobs: - name: Unzip binaries run: tar -xvf ./bins.tar.gz - name: Run sqlness - run: RUST_BACKTRACE=1 ./bins/sqlness-runner ${{ matrix.mode.opts }} -c ./tests/cases --bins-dir ./bins --preserve-state + run: RUST_BACKTRACE=1 ./bins/sqlness-runner bare ${{ matrix.mode.opts }} -c ./tests/cases --bins-dir ./bins --preserve-state - name: Upload sqlness logs if: failure() uses: actions/upload-artifact@v4 diff --git a/Cargo.lock b/Cargo.lock index 2607814b16..6e88493e92 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5825,9 +5825,9 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.14" +version = "0.1.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc2fdfdbff08affe55bb779f33b053aa1fe5dd5b54c257343c17edfa55711bdb" +checksum = "3c6995591a8f1380fcb4ba966a252a4b29188d51d2b89e3a252f5305be65aea8" dependencies = [ "base64 0.22.1", "bytes", @@ -5841,7 +5841,7 @@ dependencies = [ "libc", "percent-encoding", "pin-project-lite", - "socket2 0.5.10", + "socket2 0.6.0", "tokio", "tower-service", "tracing", diff --git a/Makefile b/Makefile index 04e022d13e..d94426554e 100644 --- a/Makefile +++ b/Makefile @@ -169,7 +169,7 @@ nextest: ## Install nextest tools. .PHONY: sqlness-test sqlness-test: ## Run sqlness test. - cargo sqlness ${SQLNESS_OPTS} + cargo sqlness bare ${SQLNESS_OPTS} RUNS ?= 1 FUZZ_TARGET ?= fuzz_alter_table diff --git a/tests/README.md b/tests/README.md index ae2d42d9c5..98435399ad 100644 --- a/tests/README.md +++ b/tests/README.md @@ -32,7 +32,7 @@ To run test with kafka, you need to pass the option `-w kafka`. If no other opti Otherwise, you can additionally pass the your existing kafka environment to sqlness with `-k` option. 
E.g.: ```shell -cargo sqlness -w kafka -k localhost:9092 +cargo sqlness bare -w kafka -k localhost:9092 ``` In this case, sqlness will not start its own kafka cluster and the one you provided instead. @@ -42,7 +42,7 @@ In this case, sqlness will not start its own kafka cluster and the one you provi Unlike other tests, this harness is in a binary target form. You can run it with: ```shell -cargo sqlness +cargo sqlness bare ``` It automatically finishes the following procedures: compile `GreptimeDB`, start it, grab tests and feed it to diff --git a/tests/compat/util.sh b/tests/compat/util.sh index 3158e7ecd7..94ec7ea3a2 100755 --- a/tests/compat/util.sh +++ b/tests/compat/util.sh @@ -64,19 +64,19 @@ run_test() { then echo " === Running forward compat test ..." echo " === Run test: write with current GreptimeDB" - $runner --bins-dir $(dirname $bin_new) --case-dir $write_case_dir + $runner bare --bins-dir $(dirname $bin_new) --case-dir $write_case_dir else echo " === Running backward compat test ..." echo " === Run test: write with old GreptimeDB" - $runner --bins-dir $(dirname $bin_old) --case-dir $write_case_dir + $runner bare --bins-dir $(dirname $bin_old) --case-dir $write_case_dir fi if [ "$forward" == 'forward' ] then echo " === Run test: read with old GreptimeDB" - $runner --bins-dir $(dirname $bin_old) --case-dir $read_case_dir + $runner bare --bins-dir $(dirname $bin_old) --case-dir $read_case_dir else echo " === Run test: read with current GreptimeDB" - $runner --bins-dir $(dirname $bin_new) --case-dir $read_case_dir + $runner bare --bins-dir $(dirname $bin_new) --case-dir $read_case_dir fi } diff --git a/tests/runner/src/client.rs b/tests/runner/src/client.rs new file mode 100644 index 0000000000..163e827e53 --- /dev/null +++ b/tests/runner/src/client.rs @@ -0,0 +1,235 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::net::SocketAddr; +use std::time::Duration; + +use client::error::ServerSnafu; +use client::{ + Client, DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, Database, OutputData, RecordBatches, +}; +use common_error::ext::ErrorExt; +use common_query::Output; +use mysql::prelude::Queryable; +use mysql::{Conn as MySqlClient, Row as MySqlRow}; +use tokio_postgres::{Client as PgClient, SimpleQueryMessage as PgRow}; + +use crate::util::retry_with_backoff; + +/// A client that can connect to GreptimeDB using multiple protocols. +pub struct MultiProtocolClient { + grpc_client: Database, + pg_client: PgClient, + mysql_client: MySqlClient, +} + +/// The result of a MySQL query. +pub enum MysqlSqlResult { + AffectedRows(u64), + Rows(Vec), +} + +impl MultiProtocolClient { + /// Connect to the GreptimeDB server using multiple protocols. + /// + /// # Arguments + /// + /// * `grpc_server_addr` - The address of the GreptimeDB server. + /// * `pg_server_addr` - The address of the Postgres server. + /// * `mysql_server_addr` - The address of the MySQL server. + /// + /// # Returns + /// + /// A `MultiProtocolClient` instance. 
+ /// + /// # Panics + /// + /// Panics if the server addresses are invalid or the connection fails. + pub async fn connect( + grpc_server_addr: &str, + pg_server_addr: &str, + mysql_server_addr: &str, + ) -> MultiProtocolClient { + let grpc_client = Database::new( + DEFAULT_CATALOG_NAME, + DEFAULT_SCHEMA_NAME, + Client::with_urls(vec![grpc_server_addr]), + ); + let pg_client = create_postgres_client(pg_server_addr).await; + let mysql_client = create_mysql_client(mysql_server_addr).await; + MultiProtocolClient { + grpc_client, + pg_client, + mysql_client, + } + } + + /// Reconnect the MySQL client. + pub async fn reconnect_mysql_client(&mut self, mysql_server_addr: &str) { + self.mysql_client = create_mysql_client(mysql_server_addr).await; + } + + /// Reconnect the Postgres client. + pub async fn reconnect_pg_client(&mut self, pg_server_addr: &str) { + self.pg_client = create_postgres_client(pg_server_addr).await; + } + + /// Execute a query on the Postgres server. + pub async fn postgres_query(&mut self, query: &str) -> Result, String> { + match self.pg_client.simple_query(query).await { + Ok(rows) => Ok(rows), + Err(e) => Err(format!("Failed to execute query, encountered: {:?}", e)), + } + } + + /// Execute a query on the MySQL server. + pub async fn mysql_query(&mut self, query: &str) -> Result { + let result = self.mysql_client.query_iter(query); + match result { + Ok(result) => { + let mut rows = vec![]; + let affected_rows = result.affected_rows(); + for row in result { + match row { + Ok(r) => rows.push(r), + Err(e) => { + return Err(format!("Failed to parse query result, err: {:?}", e)); + } + } + } + + if rows.is_empty() { + Ok(MysqlSqlResult::AffectedRows(affected_rows)) + } else { + Ok(MysqlSqlResult::Rows(rows)) + } + } + Err(e) => Err(format!("Failed to execute query, err: {:?}", e)), + } + } + + /// Execute a query on the GreptimeDB server. + pub async fn grpc_query(&mut self, query: &str) -> Result { + let query_str = query.trim().to_lowercase(); + if query_str.starts_with("use ") { + // use [db] + let database = query + .split_ascii_whitespace() + .nth(1) + .expect("Illegal `USE` statement: expecting a database.") + .trim_end_matches(';'); + self.grpc_client.set_schema(database); + Ok(Output::new_with_affected_rows(0)) + } else if query_str.starts_with("set time_zone") + || query_str.starts_with("set session time_zone") + || query_str.starts_with("set local time_zone") + { + // set time_zone='xxx' + let timezone = query + .split('=') + .nth(1) + .expect("Illegal `SET TIMEZONE` statement: expecting a timezone expr.") + .trim() + .strip_prefix('\'') + .unwrap() + .strip_suffix("';") + .unwrap(); + + self.grpc_client.set_timezone(timezone); + Ok(Output::new_with_affected_rows(0)) + } else { + let mut result = self.grpc_client.sql(&query).await; + if let Ok(Output { + data: OutputData::Stream(stream), + .. + }) = result + { + match RecordBatches::try_collect(stream).await { + Ok(recordbatches) => { + result = Ok(Output::new_with_record_batches(recordbatches)); + } + Err(e) => { + let status_code = e.status_code(); + let msg = e.output_msg(); + result = ServerSnafu { + code: status_code, + msg, + } + .fail(); + } + } + } + + result + } + } +} + +/// Create a Postgres client with retry. +/// +/// # Panics +/// +/// Panics if the Postgres server address is invalid or the connection fails. 
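A sketch of how a caller might drive this client, assuming an async context; the ports are GreptimeDB's conventional defaults (4001 gRPC, 4002 MySQL, 4003 Postgres), not values taken from this patch:

```rust
use crate::client::{MultiProtocolClient, MysqlSqlResult};

// Illustrative driver only; addresses are examples.
async fn smoke_test() {
    let mut client = MultiProtocolClient::connect(
        "127.0.0.1:4001", // gRPC
        "127.0.0.1:4003", // Postgres
        "127.0.0.1:4002", // MySQL
    )
    .await;

    // `USE` and `SET TIME_ZONE` are rewritten locally by `grpc_query`
    // (see above) instead of being shipped to the server.
    client.grpc_query("USE public;").await.unwrap();

    match client.mysql_query("SELECT 1;").await.unwrap() {
        MysqlSqlResult::AffectedRows(n) => println!("affected_rows: {n}"),
        MysqlSqlResult::Rows(rows) => println!("{} row(s)", rows.len()),
    }
}
```

Because `grpc_query` resolves `USE` and `SET TIME_ZONE` client-side, the harness can switch schema and timezone without a server round trip.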
+async fn create_postgres_client(pg_server_addr: &str) -> PgClient { + let sockaddr: SocketAddr = pg_server_addr.parse().unwrap_or_else(|_| { + panic!("Failed to parse the Postgres server address {pg_server_addr}. Please check if the address is in the format of `ip:port`.") + }); + let mut config = tokio_postgres::config::Config::new(); + config.host(sockaddr.ip().to_string()); + config.port(sockaddr.port()); + config.dbname(DEFAULT_SCHEMA_NAME); + + retry_with_backoff( + || async { + config + .connect(tokio_postgres::NoTls) + .await + .map(|(client, conn)| { + tokio::spawn(conn); + client + }) + }, + 3, + Duration::from_millis(500), + ) + .await + .unwrap_or_else(|_| { + panic!("Failed to connect to Postgres server. Please check if the server is running.") + }) +} + +/// Create a MySQL client with retry. +/// +/// # Panics +/// +/// Panics if the MySQL server address is invalid or the connection fails. +async fn create_mysql_client(mysql_server_addr: &str) -> MySqlClient { + let sockaddr: SocketAddr = mysql_server_addr.parse().unwrap_or_else(|_| { + panic!("Failed to parse the MySQL server address {mysql_server_addr}. Please check if the address is in the format of `ip:port`.") + }); + let ops = mysql::OptsBuilder::new() + .ip_or_hostname(Some(sockaddr.ip().to_string())) + .tcp_port(sockaddr.port()) + .db_name(Some(DEFAULT_SCHEMA_NAME)); + + retry_with_backoff( + || async { mysql::Conn::new(ops.clone()) }, + 3, + Duration::from_millis(500), + ) + .await + .unwrap_or_else(|_| { + panic!("Failed to connect to MySQL server. Please check if the server is running.") + }) +} diff --git a/tests/runner/src/cmd.rs b/tests/runner/src/cmd.rs new file mode 100644 index 0000000000..f7aaacfc73 --- /dev/null +++ b/tests/runner/src/cmd.rs @@ -0,0 +1,54 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +pub(crate) mod bare; +pub(crate) mod kube; + +use std::path::PathBuf; + +use bare::BareCommand; +use clap::Parser; +use kube::KubeCommand; + +#[derive(Parser)] +#[clap(author, version, about, long_about = None)] +pub struct Command { + #[clap(subcommand)] + pub subcmd: SubCommand, +} + +#[derive(Parser)] +pub enum SubCommand { + Bare(BareCommand), + Kube(KubeCommand), +} + +#[derive(Debug, Parser)] +pub struct SqlnessConfig { + /// Directory of test cases + #[clap(short, long)] + pub case_dir: Option, + + /// Fail this run as soon as one case fails if true + #[arg(short, long, default_value = "false")] + pub fail_fast: bool, + + /// Environment Configuration File + #[clap(short, long, default_value = "config.toml")] + pub env_config_file: String, + + /// Name of test cases to run. Accept as a regexp. 
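Both connection helpers above delegate their retry loops to `util::retry_with_backoff`, whose hunk falls outside this excerpt. A plausible sketch under that caveat, mirroring the inline loops this patch deletes from `env.rs` (3 attempts, 500 ms initial delay, doubling):

```rust
use std::future::Future;
use std::time::Duration;

// Hypothetical reconstruction of util::retry_with_backoff: retry an async
// operation with exponential backoff, returning the last error on exhaustion.
pub async fn retry_with_backoff<F, Fut, T, E>(
    mut op: F,
    mut max_retries: usize,
    mut backoff: Duration,
) -> Result<T, E>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<T, E>>,
{
    loop {
        match op().await {
            Ok(value) => return Ok(value),
            Err(err) => {
                if max_retries <= 1 {
                    return Err(err);
                }
                max_retries -= 1;
                tokio::time::sleep(backoff).await;
                backoff *= 2; // 500ms -> 1s -> 2s for the call sites above
            }
        }
    }
}
```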
+ #[clap(short, long, default_value = ".*")] + pub test_filter: String, +} diff --git a/tests/runner/src/cmd/bare.rs b/tests/runner/src/cmd/bare.rs new file mode 100644 index 0000000000..bc9f00bf0d --- /dev/null +++ b/tests/runner/src/cmd/bare.rs @@ -0,0 +1,206 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::path::PathBuf; +use std::sync::Arc; + +use clap::{Parser, ValueEnum}; +use sqlness::interceptor::Registry; +use sqlness::{ConfigBuilder, Runner}; + +use crate::cmd::SqlnessConfig; +use crate::env::bare::{Env, ServiceProvider, StoreConfig, WalConfig}; +use crate::{protocol_interceptor, util}; + +#[derive(ValueEnum, Debug, Clone)] +#[clap(rename_all = "snake_case")] +enum Wal { + RaftEngine, + Kafka, +} + +// add a group to ensure that all server addresses are set together +#[derive(clap::Args, Debug, Clone, Default)] +pub(crate) struct ServerAddr { + /// Address of the grpc server. + #[clap(short, long)] + pub(crate) server_addr: Option, + + /// Address of the postgres server. Must be set if server_addr is set. + #[clap(short, long, requires = "server_addr")] + pub(crate) pg_server_addr: Option, + + /// Address of the mysql server. Must be set if server_addr is set. + #[clap(short, long, requires = "server_addr")] + pub(crate) mysql_server_addr: Option, +} + +#[derive(Debug, Parser)] +/// Run sqlness tests in bare mode. +pub struct BareCommand { + #[clap(flatten)] + config: SqlnessConfig, + + /// Addresses of the server. + #[command(flatten)] + server_addr: ServerAddr, + + /// The type of Wal. + #[clap(short, long, default_value = "raft_engine")] + wal: Wal, + + /// The kafka wal broker endpoints. This config will suppress sqlness runner + /// from starting a kafka cluster, and use the given endpoint as kafka backend. + #[clap(short, long)] + kafka_wal_broker_endpoints: Option, + + /// The path to the directory where GreptimeDB's binaries resides. + /// If not set, sqlness will build GreptimeDB on the fly. + #[clap(long)] + bins_dir: Option, + + /// Preserve persistent state in the temporary directory. + /// This may affect future test runs. + #[clap(long)] + preserve_state: bool, + + /// Pull Different versions of GreptimeDB on need. + #[clap(long, default_value = "true")] + pull_version_on_need: bool, + + /// The store addresses for metadata, if empty, will use memory store. + #[clap(long)] + store_addrs: Vec, + + /// Whether to setup etcd, by default it is false. + #[clap(long, default_value = "false")] + setup_etcd: bool, + + /// Whether to setup pg, by default it is false. + #[clap(long, default_missing_value = "", num_args(0..=1))] + setup_pg: Option, + + /// Whether to setup mysql, by default it is false. + #[clap(long, default_missing_value = "", num_args(0..=1))] + setup_mysql: Option, + + /// The number of jobs to run in parallel. Default to half of the cores. + #[clap(short, long, default_value = "0")] + jobs: usize, + + /// Extra command line arguments when starting GreptimeDB binaries. 
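`--setup-pg` and `--setup-mysql` above combine `default_missing_value = ""` with `num_args(0..=1)`, so each works as a bare switch or with a value; the `ServiceProvider` conversion later in this patch maps the empty string to `Create` and anything else to `External`. Illustrative invocations (the address is a placeholder):

```shell
# Bare flag: clap supplies "", i.e. ServiceProvider::Create, and the runner
# starts its own postgres for metadata.
cargo sqlness bare --setup-pg
# With a value: ServiceProvider::External, reusing a running postgres.
cargo sqlness bare --setup-pg 127.0.0.1:5432
```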
+ #[clap(long)] + extra_args: Vec, +} + +impl BareCommand { + pub async fn run(mut self) { + let temp_dir = tempfile::Builder::new() + .prefix("sqlness") + .tempdir() + .unwrap(); + let sqlness_home = temp_dir.keep(); + + let mut interceptor_registry: Registry = Default::default(); + interceptor_registry.register( + protocol_interceptor::PREFIX, + Arc::new(protocol_interceptor::ProtocolInterceptorFactory), + ); + + if let Some(d) = &self.config.case_dir + && !d.is_dir() + { + panic!("{} is not a directory", d.display()); + } + if self.jobs == 0 { + self.jobs = num_cpus::get() / 2; + } + + // normalize parallelism to 1 if any of the following conditions are met: + // Note: parallelism in pg and mysql is possible, but need configuration. + if self.server_addr.server_addr.is_some() + || self.setup_etcd + || self.setup_pg.is_some() + || self.setup_mysql.is_some() + || self.kafka_wal_broker_endpoints.is_some() + || self.config.test_filter != ".*" + { + self.jobs = 1; + println!( + "Normalizing parallelism to 1 due to server addresses, etcd/pg/mysql setup, or test filter usage" + ); + } + + let config = ConfigBuilder::default() + .case_dir(util::get_case_dir(self.config.case_dir)) + .fail_fast(self.config.fail_fast) + .test_filter(self.config.test_filter) + .follow_links(true) + .env_config_file(self.config.env_config_file) + .interceptor_registry(interceptor_registry) + .parallelism(self.jobs) + .build() + .unwrap(); + + let wal = match self.wal { + Wal::RaftEngine => WalConfig::RaftEngine, + Wal::Kafka => WalConfig::Kafka { + needs_kafka_cluster: self.kafka_wal_broker_endpoints.is_none(), + broker_endpoints: self + .kafka_wal_broker_endpoints + .map(|s| s.split(',').map(|s| s.to_string()).collect()) + // otherwise default to the same port in `kafka-cluster.yml` + .unwrap_or(vec!["127.0.0.1:9092".to_string()]), + }, + }; + + let store = StoreConfig { + store_addrs: self.store_addrs.clone(), + setup_etcd: self.setup_etcd, + setup_pg: self.setup_pg, + setup_mysql: self.setup_mysql, + }; + + let runner = Runner::new( + config, + Env::new( + sqlness_home.clone(), + self.server_addr, + wal, + self.pull_version_on_need, + self.bins_dir, + store, + self.extra_args, + ), + ); + match runner.run().await { + Ok(_) => println!("\x1b[32mAll sqlness tests passed!\x1b[0m"), + Err(e) => { + println!("\x1b[31mTest failed: {}\x1b[0m", e); + std::process::exit(1); + } + } + + // clean up and exit + if !self.preserve_state { + if self.setup_etcd { + println!("Stopping etcd"); + util::stop_rm_etcd(); + } + // TODO(weny): remove postgre and mysql containers + println!("Removing state in {:?}", sqlness_home); + tokio::fs::remove_dir_all(sqlness_home).await.unwrap(); + } + } +} diff --git a/tests/runner/src/cmd/kube.rs b/tests/runner/src/cmd/kube.rs new file mode 100644 index 0000000000..650ed690b2 --- /dev/null +++ b/tests/runner/src/cmd/kube.rs @@ -0,0 +1,95 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
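A consequence of the normalization block in `BareCommand::run` above: any flag implying shared external state (a fixed server address, etcd/pg/mysql setup, external kafka brokers) or a non-default test filter collapses the run to one job, so `--jobs` only takes effect for plain self-contained runs. Hypothetical invocations:

```shell
# Runs with 4 parallel jobs:
cargo sqlness bare --jobs 4
# Normalized back to 1 job (with a console note) because of the external kafka endpoint:
cargo sqlness bare --jobs 4 -w kafka -k localhost:9092
```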
+ +use std::sync::Arc; + +use clap::Parser; +use sqlness::interceptor::Registry; +use sqlness::{ConfigBuilder, Runner}; + +use crate::cmd::SqlnessConfig; +use crate::env::kube::{Env, NaiveResourcesManager}; +use crate::{protocol_interceptor, util}; + +#[derive(Debug, Parser)] +/// Run sqlness tests in kube mode. +pub struct KubeCommand { + #[clap(flatten)] + config: SqlnessConfig, + + /// Whether to delete the namespace on stop. + #[clap(long, default_value = "false")] + delete_namespace_on_stop: bool, + + /// Address of the grpc server. + #[clap(short, long)] + server_addr: String, + + /// Address of the postgres server. Must be set if server_addr is set. + #[clap(short, long)] + pg_server_addr: String, + + /// Address of the mysql server. Must be set if server_addr is set. + #[clap(short, long)] + mysql_server_addr: String, + + /// The namespace of the GreptimeDB. + #[clap(short, long)] + namespace: String, +} + +impl KubeCommand { + pub async fn run(self) { + let mut interceptor_registry: Registry = Default::default(); + interceptor_registry.register( + protocol_interceptor::PREFIX, + Arc::new(protocol_interceptor::ProtocolInterceptorFactory), + ); + + if let Some(d) = &self.config.case_dir + && !d.is_dir() + { + panic!("{} is not a directory", d.display()); + } + + let config = ConfigBuilder::default() + .case_dir(util::get_case_dir(self.config.case_dir)) + .fail_fast(self.config.fail_fast) + .test_filter(self.config.test_filter) + .follow_links(true) + .env_config_file(self.config.env_config_file) + .interceptor_registry(interceptor_registry) + .build() + .unwrap(); + + let runner = Runner::new( + config, + Env { + delete_namespace_on_stop: self.delete_namespace_on_stop, + server_addr: self.server_addr, + pg_server_addr: self.pg_server_addr, + mysql_server_addr: self.mysql_server_addr, + database_manager: Arc::new(()), + resources_manager: Arc::new(NaiveResourcesManager::new(self.namespace)), + }, + ); + match runner.run().await { + Ok(_) => println!("\x1b[32mAll sqlness tests passed!\x1b[0m"), + Err(e) => { + println!("\x1b[31mTest failed: {}\x1b[0m", e); + std::process::exit(1); + } + } + } +} diff --git a/tests/runner/src/env.rs b/tests/runner/src/env.rs index 1a329f6756..e9a9d05fe8 100644 --- a/tests/runner/src/env.rs +++ b/tests/runner/src/env.rs @@ -12,1024 +12,5 @@ // See the License for the specific language governing permissions and // limitations under the License. 
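Unlike `bare`, the `kube` subcommand above never spawns GreptimeDB processes; it points the runner at an existing in-cluster deployment, which is why all three server addresses and the namespace are plain required arguments. An illustrative invocation (addresses and namespace are placeholders, flag names come from the clap derives above):

```shell
cargo sqlness kube \
  --server-addr 127.0.0.1:4001 \
  --pg-server-addr 127.0.0.1:4003 \
  --mysql-server-addr 127.0.0.1:4002 \
  --namespace my-greptimedb \
  --delete-namespace-on-stop
```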
-use std::borrow::Cow; -use std::collections::HashMap; -use std::fmt::Display; -use std::fs::OpenOptions; -use std::io; -use std::io::Write; -use std::net::SocketAddr; -use std::path::{Path, PathBuf}; -use std::process::{Child, Command}; -use std::sync::atomic::{AtomicU32, Ordering}; -use std::sync::{Arc, Mutex}; -use std::time::Duration; - -use async_trait::async_trait; -use client::error::ServerSnafu; -use client::{ - Client, DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, Database as DB, Error as ClientError, -}; -use common_error::ext::ErrorExt; -use common_query::{Output, OutputData}; -use common_recordbatch::RecordBatches; -use datatypes::data_type::ConcreteDataType; -use datatypes::scalars::ScalarVectorBuilder; -use datatypes::schema::{ColumnSchema, Schema}; -use datatypes::vectors::{StringVectorBuilder, VectorRef}; -use mysql::prelude::Queryable; -use mysql::{Conn as MySqlClient, Row as MySqlRow}; -use sqlness::{Database, EnvController, QueryContext}; -use tokio::sync::Mutex as TokioMutex; -use tokio_postgres::{Client as PgClient, SimpleQueryMessage as PgRow}; - -use crate::protocol_interceptor::{MYSQL, PROTOCOL_KEY}; -use crate::server_mode::ServerMode; -use crate::util::{PROGRAM, get_workspace_root, maybe_pull_binary}; -use crate::{ServerAddr, util}; - -// standalone mode -const SERVER_MODE_STANDALONE_IDX: usize = 0; -// distributed mode -const SERVER_MODE_METASRV_IDX: usize = 0; -const SERVER_MODE_DATANODE_START_IDX: usize = 1; -const SERVER_MODE_FRONTEND_IDX: usize = 4; -const SERVER_MODE_FLOWNODE_IDX: usize = 5; - -#[derive(Clone)] -pub enum WalConfig { - RaftEngine, - Kafka { - /// Indicates whether the runner needs to start a kafka cluster - /// (it might be available in the external system environment). - needs_kafka_cluster: bool, - broker_endpoints: Vec, - }, -} - -#[derive(Debug, Clone)] -pub(crate) enum ServiceProvider { - Create, - External(String), -} - -impl From<&str> for ServiceProvider { - fn from(value: &str) -> Self { - if value.is_empty() { - Self::Create - } else { - Self::External(value.to_string()) - } - } -} - -#[derive(Clone)] -pub struct StoreConfig { - pub store_addrs: Vec, - pub setup_etcd: bool, - pub(crate) setup_pg: Option, - pub(crate) setup_mysql: Option, -} - -#[derive(Clone)] -pub struct Env { - sqlness_home: PathBuf, - server_addrs: ServerAddr, - wal: WalConfig, - - /// The path to the directory that contains the pre-built GreptimeDB binary. - /// When running in CI, this is expected to be set. - /// If not set, this runner will build the GreptimeDB binary itself when needed, and set this field by then. - bins_dir: Arc>>, - /// The path to the directory that contains the old pre-built GreptimeDB binaries. - versioned_bins_dirs: Arc>>, - /// Pull different versions of GreptimeDB on need. - pull_version_on_need: bool, - /// Store address for metasrv metadata - store_config: StoreConfig, - /// Extra command line arguments when starting GreptimeDB binaries. 
- extra_args: Vec, -} - -#[async_trait] -impl EnvController for Env { - type DB = GreptimeDB; - - async fn start(&self, mode: &str, id: usize, _config: Option<&Path>) -> Self::DB { - if self.server_addrs.server_addr.is_some() && id > 0 { - panic!("Parallel test mode is not supported when server address is already set."); - } - - unsafe { - std::env::set_var("SQLNESS_HOME", self.sqlness_home.display().to_string()); - } - match mode { - "standalone" => self.start_standalone(id).await, - "distributed" => self.start_distributed(id).await, - _ => panic!("Unexpected mode: {mode}"), - } - } - - /// Stop one [`Database`]. - async fn stop(&self, _mode: &str, mut database: Self::DB) { - database.stop(); - } -} - -impl Env { - pub fn new( - data_home: PathBuf, - server_addrs: ServerAddr, - wal: WalConfig, - pull_version_on_need: bool, - bins_dir: Option, - store_config: StoreConfig, - extra_args: Vec, - ) -> Self { - Self { - sqlness_home: data_home, - server_addrs, - wal, - pull_version_on_need, - bins_dir: Arc::new(Mutex::new(bins_dir.clone())), - versioned_bins_dirs: Arc::new(Mutex::new(HashMap::from_iter([( - "latest".to_string(), - bins_dir.clone().unwrap_or(util::get_binary_dir("debug")), - )]))), - store_config, - extra_args, - } - } - - async fn start_standalone(&self, id: usize) -> GreptimeDB { - println!("Starting standalone instance id: {id}"); - - if self.server_addrs.server_addr.is_some() { - self.connect_db(&self.server_addrs, id).await - } else { - self.build_db(); - self.setup_wal(); - let mut db_ctx = GreptimeDBContext::new(self.wal.clone(), self.store_config.clone()); - - let server_mode = ServerMode::random_standalone(); - db_ctx.set_server_mode(server_mode.clone(), SERVER_MODE_STANDALONE_IDX); - let server_addr = server_mode.server_addr().unwrap(); - let server_process = self.start_server(server_mode, &db_ctx, id, true).await; - - let mut greptimedb = self.connect_db(&server_addr, id).await; - greptimedb.server_processes = Some(Arc::new(Mutex::new(vec![server_process]))); - greptimedb.is_standalone = true; - greptimedb.ctx = db_ctx; - - greptimedb - } - } - - async fn start_distributed(&self, id: usize) -> GreptimeDB { - if self.server_addrs.server_addr.is_some() { - self.connect_db(&self.server_addrs, id).await - } else { - self.build_db(); - self.setup_wal(); - self.setup_etcd(); - self.setup_pg(); - self.setup_mysql().await; - let mut db_ctx = GreptimeDBContext::new(self.wal.clone(), self.store_config.clone()); - - // start a distributed GreptimeDB - let meta_server_mode = ServerMode::random_metasrv(); - let metasrv_port = match &meta_server_mode { - ServerMode::Metasrv { - rpc_server_addr, .. - } => rpc_server_addr - .split(':') - .nth(1) - .unwrap() - .parse::() - .unwrap(), - _ => panic!( - "metasrv mode not set, maybe running in remote mode which doesn't support restart?" 
- ), - }; - db_ctx.set_server_mode(meta_server_mode.clone(), SERVER_MODE_METASRV_IDX); - let meta_server = self.start_server(meta_server_mode, &db_ctx, id, true).await; - - let datanode_1_mode = ServerMode::random_datanode(metasrv_port, 0); - db_ctx.set_server_mode(datanode_1_mode.clone(), SERVER_MODE_DATANODE_START_IDX); - let datanode_1 = self.start_server(datanode_1_mode, &db_ctx, id, true).await; - let datanode_2_mode = ServerMode::random_datanode(metasrv_port, 1); - db_ctx.set_server_mode(datanode_2_mode.clone(), SERVER_MODE_DATANODE_START_IDX + 1); - let datanode_2 = self.start_server(datanode_2_mode, &db_ctx, id, true).await; - let datanode_3_mode = ServerMode::random_datanode(metasrv_port, 2); - db_ctx.set_server_mode(datanode_3_mode.clone(), SERVER_MODE_DATANODE_START_IDX + 2); - let datanode_3 = self.start_server(datanode_3_mode, &db_ctx, id, true).await; - - let frontend_mode = ServerMode::random_frontend(metasrv_port); - let server_addr = frontend_mode.server_addr().unwrap(); - db_ctx.set_server_mode(frontend_mode.clone(), SERVER_MODE_FRONTEND_IDX); - let frontend = self.start_server(frontend_mode, &db_ctx, id, true).await; - - let flownode_mode = ServerMode::random_flownode(metasrv_port, 0); - db_ctx.set_server_mode(flownode_mode.clone(), SERVER_MODE_FLOWNODE_IDX); - let flownode = self.start_server(flownode_mode, &db_ctx, id, true).await; - - let mut greptimedb = self.connect_db(&server_addr, id).await; - - greptimedb.metasrv_process = Some(meta_server).into(); - greptimedb.server_processes = Some(Arc::new(Mutex::new(vec![ - datanode_1, datanode_2, datanode_3, - ]))); - greptimedb.frontend_process = Some(frontend).into(); - greptimedb.flownode_process = Some(flownode).into(); - greptimedb.is_standalone = false; - greptimedb.ctx = db_ctx; - - greptimedb - } - } - - async fn create_pg_client(&self, pg_server_addr: &str) -> PgClient { - let sockaddr: SocketAddr = pg_server_addr.parse().expect( - "Failed to parse the Postgres server address. Please check if the address is in the format of `ip:port`.", - ); - let mut config = tokio_postgres::config::Config::new(); - config.host(sockaddr.ip().to_string()); - config.port(sockaddr.port()); - config.dbname(DEFAULT_SCHEMA_NAME); - - // retry to connect to Postgres server until success - const MAX_RETRY: usize = 3; - let mut backoff = Duration::from_millis(500); - for _ in 0..MAX_RETRY { - if let Ok((pg_client, conn)) = config.connect(tokio_postgres::NoTls).await { - tokio::spawn(conn); - return pg_client; - } - tokio::time::sleep(backoff).await; - backoff *= 2; - } - panic!("Failed to connect to Postgres server. Please check if the server is running."); - } - - async fn create_mysql_client(&self, mysql_server_addr: &str) -> MySqlClient { - let sockaddr: SocketAddr = mysql_server_addr.parse().expect( - "Failed to parse the MySQL server address. Please check if the address is in the format of `ip:port`.", - ); - let ops = mysql::OptsBuilder::new() - .ip_or_hostname(Some(sockaddr.ip().to_string())) - .tcp_port(sockaddr.port()) - .db_name(Some(DEFAULT_SCHEMA_NAME)); - // retry to connect to MySQL server until success - const MAX_RETRY: usize = 3; - let mut backoff = Duration::from_millis(500); - - for _ in 0..MAX_RETRY { - // exponential backoff - if let Ok(client) = mysql::Conn::new(ops.clone()) { - return client; - } - tokio::time::sleep(backoff).await; - backoff *= 2; - } - - panic!("Failed to connect to MySQL server. 
Please check if the server is running.") - } - - async fn connect_db(&self, server_addr: &ServerAddr, id: usize) -> GreptimeDB { - let grpc_server_addr = server_addr.server_addr.clone().unwrap(); - let pg_server_addr = server_addr.pg_server_addr.clone().unwrap(); - let mysql_server_addr = server_addr.mysql_server_addr.clone().unwrap(); - - let grpc_client = Client::with_urls(vec![grpc_server_addr.clone()]); - let db = DB::new(DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, grpc_client); - let pg_client = self.create_pg_client(&pg_server_addr).await; - let mysql_client = self.create_mysql_client(&mysql_server_addr).await; - - GreptimeDB { - grpc_client: TokioMutex::new(db), - pg_client: TokioMutex::new(pg_client), - mysql_client: TokioMutex::new(mysql_client), - server_processes: None, - metasrv_process: None.into(), - frontend_process: None.into(), - flownode_process: None.into(), - ctx: GreptimeDBContext { - time: 0, - datanode_id: Default::default(), - wal: self.wal.clone(), - store_config: self.store_config.clone(), - server_modes: Vec::new(), - }, - is_standalone: false, - env: self.clone(), - id, - } - } - - fn stop_server(process: &mut Child) { - let _ = process.kill(); - let _ = process.wait(); - } - - async fn start_server( - &self, - mode: ServerMode, - db_ctx: &GreptimeDBContext, - id: usize, - truncate_log: bool, - ) -> Child { - let log_file_name = match mode { - ServerMode::Datanode { node_id, .. } => { - db_ctx.incr_datanode_id(); - format!("greptime-{}-sqlness-datanode-{}.log", id, node_id) - } - ServerMode::Flownode { .. } => format!("greptime-{}-sqlness-flownode.log", id), - ServerMode::Frontend { .. } => format!("greptime-{}-sqlness-frontend.log", id), - ServerMode::Metasrv { .. } => format!("greptime-{}-sqlness-metasrv.log", id), - ServerMode::Standalone { .. } => format!("greptime-{}-sqlness-standalone.log", id), - }; - let stdout_file_name = self.sqlness_home.join(log_file_name).display().to_string(); - - println!("DB instance {id} log file at {stdout_file_name}"); - - let stdout_file = OpenOptions::new() - .create(true) - .write(true) - .truncate(truncate_log) - .append(!truncate_log) - .open(stdout_file_name) - .unwrap(); - - let args = mode.get_args(&self.sqlness_home, self, db_ctx, id); - let check_ip_addrs = mode.check_addrs(); - - for check_ip_addr in &check_ip_addrs { - if util::check_port(check_ip_addr.parse().unwrap(), Duration::from_secs(1)).await { - panic!( - "Port {check_ip_addr} is already in use, please check and retry.", - check_ip_addr = check_ip_addr - ); - } - } - - let program = PROGRAM; - - let bins_dir = self.bins_dir.lock().unwrap().clone().expect( - "GreptimeDB binary is not available. Please pass in the path to the directory that contains the pre-built GreptimeDB binary. 
Or you may call `self.build_db()` beforehand.", - ); - - let abs_bins_dir = bins_dir - .canonicalize() - .expect("Failed to canonicalize bins_dir"); - - let mut process = Command::new(abs_bins_dir.join(program)) - .current_dir(bins_dir.clone()) - .env("TZ", "UTC") - .args(args) - .stdout(stdout_file) - .spawn() - .unwrap_or_else(|error| { - panic!( - "Failed to start the DB with subcommand {}, Error: {error}, path: {:?}", - mode.name(), - bins_dir.join(program) - ); - }); - - for check_ip_addr in &check_ip_addrs { - if !util::check_port(check_ip_addr.parse().unwrap(), Duration::from_secs(10)).await { - Env::stop_server(&mut process); - panic!("{} doesn't up in 10 seconds, quit.", mode.name()) - } - } - - process - } - - /// stop and restart the server process - async fn restart_server(&self, db: &GreptimeDB, is_full_restart: bool) { - { - if let Some(server_process) = db.server_processes.clone() { - let mut server_processes = server_process.lock().unwrap(); - for server_process in server_processes.iter_mut() { - Env::stop_server(server_process); - } - } - - if is_full_restart { - if let Some(mut metasrv_process) = - db.metasrv_process.lock().expect("poisoned lock").take() - { - Env::stop_server(&mut metasrv_process); - } - if let Some(mut frontend_process) = - db.frontend_process.lock().expect("poisoned lock").take() - { - Env::stop_server(&mut frontend_process); - } - } - - if let Some(mut flownode_process) = - db.flownode_process.lock().expect("poisoned lock").take() - { - Env::stop_server(&mut flownode_process); - } - } - - // check if the server is distributed or standalone - let new_server_processes = if db.is_standalone { - let server_mode = db - .ctx - .get_server_mode(SERVER_MODE_STANDALONE_IDX) - .cloned() - .unwrap(); - let server_addr = server_mode.server_addr().unwrap(); - let new_server_process = self.start_server(server_mode, &db.ctx, db.id, false).await; - - *db.pg_client.lock().await = self - .create_pg_client(&server_addr.pg_server_addr.unwrap()) - .await; - *db.mysql_client.lock().await = self - .create_mysql_client(&server_addr.mysql_server_addr.unwrap()) - .await; - vec![new_server_process] - } else { - db.ctx.reset_datanode_id(); - if is_full_restart { - let metasrv_mode = db - .ctx - .get_server_mode(SERVER_MODE_METASRV_IDX) - .cloned() - .unwrap(); - let metasrv = self.start_server(metasrv_mode, &db.ctx, db.id, false).await; - db.metasrv_process - .lock() - .expect("lock poisoned") - .replace(metasrv); - - // wait for metasrv to start - // since it seems older version of db might take longer to complete election - tokio::time::sleep(Duration::from_secs(5)).await; - } - - let mut processes = vec![]; - for i in 0..3 { - let datanode_mode = db - .ctx - .get_server_mode(SERVER_MODE_DATANODE_START_IDX + i) - .cloned() - .unwrap(); - let new_server_process = self - .start_server(datanode_mode, &db.ctx, db.id, false) - .await; - processes.push(new_server_process); - } - - if is_full_restart { - let frontend_mode = db - .ctx - .get_server_mode(SERVER_MODE_FRONTEND_IDX) - .cloned() - .unwrap(); - let frontend = self - .start_server(frontend_mode, &db.ctx, db.id, false) - .await; - db.frontend_process - .lock() - .expect("lock poisoned") - .replace(frontend); - } - - let flownode_mode = db - .ctx - .get_server_mode(SERVER_MODE_FLOWNODE_IDX) - .cloned() - .unwrap(); - let flownode = self - .start_server(flownode_mode, &db.ctx, db.id, false) - .await; - db.flownode_process - .lock() - .expect("lock poisoned") - .replace(flownode); - - processes - }; - - if let 
Some(server_processes) = db.server_processes.clone() { - let mut server_processes = server_processes.lock().unwrap(); - *server_processes = new_server_processes; - } - } - - /// Setup kafka wal cluster if needed. The counterpart is in [GreptimeDB::stop]. - fn setup_wal(&self) { - if matches!(self.wal, WalConfig::Kafka { needs_kafka_cluster, .. } if needs_kafka_cluster) { - util::setup_wal(); - } - } - - /// Setup etcd if needed. - fn setup_etcd(&self) { - if self.store_config.setup_etcd { - let client_ports = self - .store_config - .store_addrs - .iter() - .map(|s| s.split(':').nth(1).unwrap().parse::().unwrap()) - .collect::>(); - util::setup_etcd(client_ports, None, None); - } - } - - /// Setup PostgreSql if needed. - fn setup_pg(&self) { - if matches!(self.store_config.setup_pg, Some(ServiceProvider::Create)) { - let client_ports = self - .store_config - .store_addrs - .iter() - .map(|s| s.split(':').nth(1).unwrap().parse::().unwrap()) - .collect::>(); - let client_port = client_ports.first().unwrap_or(&5432); - util::setup_pg(*client_port, None); - } - } - - /// Setup MySql if needed. - async fn setup_mysql(&self) { - if matches!(self.store_config.setup_mysql, Some(ServiceProvider::Create)) { - let client_ports = self - .store_config - .store_addrs - .iter() - .map(|s| s.split(':').nth(1).unwrap().parse::().unwrap()) - .collect::>(); - let client_port = client_ports.first().unwrap_or(&3306); - util::setup_mysql(*client_port, None); - - // Docker of MySQL starts slowly, so we need to wait for a while - tokio::time::sleep(Duration::from_secs(10)).await; - } - } - - /// Build the DB with `cargo build --bin greptime` - fn build_db(&self) { - if self.bins_dir.lock().unwrap().is_some() { - return; - } - - println!("Going to build the DB..."); - let output = Command::new("cargo") - .current_dir(util::get_workspace_root()) - .args([ - "build", - "--bin", - "greptime", - "--features", - "pg_kvbackend,mysql_kvbackend", - ]) - .output() - .expect("Failed to start GreptimeDB"); - if !output.status.success() { - println!("Failed to build GreptimeDB, {}", output.status); - println!("Cargo build stdout:"); - io::stdout().write_all(&output.stdout).unwrap(); - println!("Cargo build stderr:"); - io::stderr().write_all(&output.stderr).unwrap(); - panic!(); - } - - let _ = self - .bins_dir - .lock() - .unwrap() - .insert(util::get_binary_dir("debug")); - } - - pub(crate) fn extra_args(&self) -> &Vec { - &self.extra_args - } -} - -pub struct GreptimeDB { - server_processes: Option>>>, - metasrv_process: Mutex>, - frontend_process: Mutex>, - flownode_process: Mutex>, - grpc_client: TokioMutex, - pg_client: TokioMutex, - mysql_client: TokioMutex, - ctx: GreptimeDBContext, - is_standalone: bool, - env: Env, - id: usize, -} - -impl GreptimeDB { - async fn postgres_query(&self, _ctx: QueryContext, query: String) -> Box { - let client = self.pg_client.lock().await; - match client.simple_query(&query).await { - Ok(rows) => Box::new(PostgresqlFormatter { rows }), - Err(e) => Box::new(format!("Failed to execute query, encountered: {:?}", e)), - } - } - - async fn mysql_query(&self, _ctx: QueryContext, query: String) -> Box { - let mut conn = self.mysql_client.lock().await; - let result = conn.query_iter(query); - Box::new(match result { - Ok(result) => { - let mut rows = vec![]; - let affected_rows = result.affected_rows(); - for row in result { - match row { - Ok(r) => rows.push(r), - Err(e) => { - return Box::new(format!("Failed to parse query result, err: {:?}", e)); - } - } - } - - if rows.is_empty() { - 
format!("affected_rows: {}", affected_rows) - } else { - format!("{}", MysqlFormatter { rows }) - } - } - Err(e) => format!("Failed to execute query, err: {:?}", e), - }) - } - - async fn grpc_query(&self, _ctx: QueryContext, query: String) -> Box { - let mut client = self.grpc_client.lock().await; - - let query_str = query.trim().to_lowercase(); - - if query_str.starts_with("use ") { - // use [db] - let database = query - .split_ascii_whitespace() - .nth(1) - .expect("Illegal `USE` statement: expecting a database.") - .trim_end_matches(';'); - client.set_schema(database); - Box::new(ResultDisplayer { - result: Ok(Output::new_with_affected_rows(0)), - }) as _ - } else if query_str.starts_with("set time_zone") - || query_str.starts_with("set session time_zone") - || query_str.starts_with("set local time_zone") - { - // set time_zone='xxx' - let timezone = query - .split('=') - .nth(1) - .expect("Illegal `SET TIMEZONE` statement: expecting a timezone expr.") - .trim() - .strip_prefix('\'') - .unwrap() - .strip_suffix("';") - .unwrap(); - - client.set_timezone(timezone); - - Box::new(ResultDisplayer { - result: Ok(Output::new_with_affected_rows(0)), - }) as _ - } else { - let mut result = client.sql(&query).await; - if let Ok(Output { - data: OutputData::Stream(stream), - .. - }) = result - { - match RecordBatches::try_collect(stream).await { - Ok(recordbatches) => { - result = Ok(Output::new_with_record_batches(recordbatches)); - } - Err(e) => { - let status_code = e.status_code(); - let msg = e.output_msg(); - result = ServerSnafu { - code: status_code, - msg, - } - .fail(); - } - } - } - Box::new(ResultDisplayer { result }) as _ - } - } -} - -#[async_trait] -impl Database for GreptimeDB { - async fn query(&self, ctx: QueryContext, query: String) -> Box { - if ctx.context.contains_key("restart") && self.env.server_addrs.server_addr.is_none() { - self.env.restart_server(self, false).await; - } else if let Some(version) = ctx.context.get("version") { - let version_bin_dir = self - .env - .versioned_bins_dirs - .lock() - .expect("lock poison") - .get(version.as_str()) - .cloned(); - - match version_bin_dir { - Some(path) if path.clone().join(PROGRAM).is_file() => { - // use version in versioned_bins_dirs - *self.env.bins_dir.lock().unwrap() = Some(path.clone()); - } - _ => { - // use version in dir files - maybe_pull_binary(version, self.env.pull_version_on_need).await; - let root = get_workspace_root(); - let new_path = PathBuf::from_iter([&root, version]); - *self.env.bins_dir.lock().unwrap() = Some(new_path); - } - } - - self.env.restart_server(self, true).await; - // sleep for a while to wait for the server to fully boot up - tokio::time::sleep(Duration::from_secs(5)).await; - } - - if let Some(protocol) = ctx.context.get(PROTOCOL_KEY) { - // protocol is bound to be either "mysql" or "postgres" - if protocol == MYSQL { - self.mysql_query(ctx, query).await - } else { - self.postgres_query(ctx, query).await - } - } else { - self.grpc_query(ctx, query).await - } - } -} - -impl GreptimeDB { - fn stop(&mut self) { - if let Some(server_processes) = self.server_processes.clone() { - let mut server_processes = server_processes.lock().unwrap(); - for mut server_process in server_processes.drain(..) 
{ - Env::stop_server(&mut server_process); - println!( - "Standalone or Datanode (pid = {}) is stopped", - server_process.id() - ); - } - } - if let Some(mut metasrv) = self - .metasrv_process - .lock() - .expect("someone else panic when holding lock") - .take() - { - Env::stop_server(&mut metasrv); - println!("Metasrv (pid = {}) is stopped", metasrv.id()); - } - if let Some(mut frontend) = self - .frontend_process - .lock() - .expect("someone else panic when holding lock") - .take() - { - Env::stop_server(&mut frontend); - println!("Frontend (pid = {}) is stopped", frontend.id()); - } - if let Some(mut flownode) = self - .flownode_process - .lock() - .expect("someone else panic when holding lock") - .take() - { - Env::stop_server(&mut flownode); - println!("Flownode (pid = {}) is stopped", flownode.id()); - } - if matches!(self.ctx.wal, WalConfig::Kafka { needs_kafka_cluster, .. } if needs_kafka_cluster) - { - util::teardown_wal(); - } - } -} - -impl Drop for GreptimeDB { - fn drop(&mut self) { - if self.env.server_addrs.server_addr.is_none() { - self.stop(); - } - } -} - -pub struct GreptimeDBContext { - /// Start time in millisecond - time: i64, - datanode_id: AtomicU32, - wal: WalConfig, - store_config: StoreConfig, - server_modes: Vec, -} - -impl GreptimeDBContext { - pub fn new(wal: WalConfig, store_config: StoreConfig) -> Self { - Self { - time: common_time::util::current_time_millis(), - datanode_id: AtomicU32::new(0), - wal, - store_config, - server_modes: Vec::new(), - } - } - - pub(crate) fn time(&self) -> i64 { - self.time - } - - pub fn is_raft_engine(&self) -> bool { - matches!(self.wal, WalConfig::RaftEngine) - } - - pub fn kafka_wal_broker_endpoints(&self) -> String { - match &self.wal { - WalConfig::RaftEngine => String::new(), - WalConfig::Kafka { - broker_endpoints, .. 
- } => serde_json::to_string(&broker_endpoints).unwrap(), - } - } - - fn incr_datanode_id(&self) { - let _ = self.datanode_id.fetch_add(1, Ordering::Relaxed); - } - - fn reset_datanode_id(&self) { - self.datanode_id.store(0, Ordering::Relaxed); - } - - pub(crate) fn store_config(&self) -> StoreConfig { - self.store_config.clone() - } - - fn set_server_mode(&mut self, mode: ServerMode, idx: usize) { - if idx >= self.server_modes.len() { - self.server_modes.resize(idx + 1, mode.clone()); - } - self.server_modes[idx] = mode; - } - - fn get_server_mode(&self, idx: usize) -> Option<&ServerMode> { - self.server_modes.get(idx) - } -} - -struct ResultDisplayer { - result: Result, -} - -impl Display for ResultDisplayer { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match &self.result { - Ok(result) => match &result.data { - OutputData::AffectedRows(rows) => { - write!(f, "Affected Rows: {rows}") - } - OutputData::RecordBatches(recordbatches) => { - let pretty = recordbatches.pretty_print().map_err(|e| e.to_string()); - match pretty { - Ok(s) => write!(f, "{s}"), - Err(e) => { - write!(f, "Failed to pretty format {recordbatches:?}, error: {e}") - } - } - } - OutputData::Stream(_) => unreachable!(), - }, - Err(e) => { - let status_code = e.status_code(); - let root_cause = e.output_msg(); - write!( - f, - "Error: {}({status_code}), {root_cause}", - status_code as u32 - ) - } - } - } -} - -struct PostgresqlFormatter { - pub rows: Vec, -} - -impl Display for PostgresqlFormatter { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if self.rows.is_empty() { - return f.write_fmt(format_args!("(Empty response)")); - } - - // create schema - let schema = match &self.rows[0] { - PgRow::CommandComplete(affected_rows) => { - write!( - f, - "{}", - ResultDisplayer { - result: Ok(Output::new_with_affected_rows(*affected_rows as usize)), - } - )?; - return Ok(()); - } - PgRow::RowDescription(desc) => Arc::new(Schema::new( - desc.iter() - .map(|column| { - ColumnSchema::new(column.name(), ConcreteDataType::string_datatype(), true) - }) - .collect(), - )), - _ => unreachable!(), - }; - if schema.num_columns() == 0 { - return Ok(()); - } - - // convert to string vectors - let mut columns: Vec = (0..schema.num_columns()) - .map(|_| StringVectorBuilder::with_capacity(schema.num_columns())) - .collect(); - for row in self.rows.iter().skip(1) { - if let PgRow::Row(row) = row { - for (i, column) in columns.iter_mut().enumerate().take(schema.num_columns()) { - column.push(row.get(i)); - } - } - } - let columns: Vec = columns - .into_iter() - .map(|mut col| Arc::new(col.finish()) as VectorRef) - .collect(); - - // construct recordbatch - let recordbatches = RecordBatches::try_from_columns(schema, columns) - .expect("Failed to construct recordbatches from columns. 
Please check the schema."); - let result_displayer = ResultDisplayer { - result: Ok(Output::new_with_record_batches(recordbatches)), - }; - write!(f, "{}", result_displayer)?; - - Ok(()) - } -} - -struct MysqlFormatter { - pub rows: Vec, -} - -impl Display for MysqlFormatter { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if self.rows.is_empty() { - return f.write_fmt(format_args!("(Empty response)")); - } - // create schema - let head_column = &self.rows[0]; - let head_binding = head_column.columns(); - let names = head_binding - .iter() - .map(|column| column.name_str()) - .collect::>>(); - let schema = Arc::new(Schema::new( - names - .iter() - .map(|name| { - ColumnSchema::new(name.to_string(), ConcreteDataType::string_datatype(), false) - }) - .collect(), - )); - - // convert to string vectors - let mut columns: Vec = (0..schema.num_columns()) - .map(|_| StringVectorBuilder::with_capacity(schema.num_columns())) - .collect(); - for row in self.rows.iter() { - for (i, name) in names.iter().enumerate() { - columns[i].push(row.get::(name).as_deref()); - } - } - let columns: Vec = columns - .into_iter() - .map(|mut col| Arc::new(col.finish()) as VectorRef) - .collect(); - - // construct recordbatch - let recordbatches = RecordBatches::try_from_columns(schema, columns) - .expect("Failed to construct recordbatches from columns. Please check the schema."); - let result_displayer = ResultDisplayer { - result: Ok(Output::new_with_record_batches(recordbatches)), - }; - write!(f, "{}", result_displayer)?; - - Ok(()) - } -} +pub mod bare; +pub mod kube; diff --git a/tests/runner/src/env/bare.rs b/tests/runner/src/env/bare.rs new file mode 100644 index 0000000000..0644151aaa --- /dev/null +++ b/tests/runner/src/env/bare.rs @@ -0,0 +1,751 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use std::collections::HashMap; +use std::fmt::Display; +use std::fs::OpenOptions; +use std::io; +use std::io::Write; +use std::path::{Path, PathBuf}; +use std::process::{Child, Command}; +use std::sync::atomic::{AtomicU32, Ordering}; +use std::sync::{Arc, Mutex}; +use std::time::Duration; + +use async_trait::async_trait; +use sqlness::{Database, EnvController, QueryContext}; +use tokio::sync::Mutex as TokioMutex; + +use crate::client::MultiProtocolClient; +use crate::cmd::bare::ServerAddr; +use crate::formatter::{ErrorFormatter, MysqlFormatter, OutputFormatter, PostgresqlFormatter}; +use crate::protocol_interceptor::{MYSQL, PROTOCOL_KEY}; +use crate::server_mode::ServerMode; +use crate::util; +use crate::util::{PROGRAM, get_workspace_root, maybe_pull_binary}; + +// standalone mode +const SERVER_MODE_STANDALONE_IDX: usize = 0; +// distributed mode +const SERVER_MODE_METASRV_IDX: usize = 0; +const SERVER_MODE_DATANODE_START_IDX: usize = 1; +const SERVER_MODE_FRONTEND_IDX: usize = 4; +const SERVER_MODE_FLOWNODE_IDX: usize = 5; + +#[derive(Clone)] +pub enum WalConfig { + RaftEngine, + Kafka { + /// Indicates whether the runner needs to start a kafka cluster + /// (it might be available in the external system environment). + needs_kafka_cluster: bool, + broker_endpoints: Vec, + }, +} + +#[derive(Debug, Clone)] +pub(crate) enum ServiceProvider { + Create, + External(String), +} + +impl From<&str> for ServiceProvider { + fn from(value: &str) -> Self { + if value.is_empty() { + Self::Create + } else { + Self::External(value.to_string()) + } + } +} + +#[derive(Clone)] +pub struct StoreConfig { + pub store_addrs: Vec, + pub setup_etcd: bool, + pub(crate) setup_pg: Option, + pub(crate) setup_mysql: Option, +} + +#[derive(Clone)] +pub struct Env { + sqlness_home: PathBuf, + server_addrs: ServerAddr, + wal: WalConfig, + + /// The path to the directory that contains the pre-built GreptimeDB binary. + /// When running in CI, this is expected to be set. + /// If not set, this runner will build the GreptimeDB binary itself when needed, and set this field by then. + bins_dir: Arc>>, + /// The path to the directory that contains the old pre-built GreptimeDB binaries. + versioned_bins_dirs: Arc>>, + /// Pull different versions of GreptimeDB on need. + pull_version_on_need: bool, + /// Store address for metasrv metadata + store_config: StoreConfig, + /// Extra command line arguments when starting GreptimeDB binaries. + extra_args: Vec, +} + +#[async_trait] +impl EnvController for Env { + type DB = GreptimeDB; + + async fn start(&self, mode: &str, id: usize, _config: Option<&Path>) -> Self::DB { + if self.server_addrs.server_addr.is_some() && id > 0 { + panic!("Parallel test mode is not supported when server address is already set."); + } + + unsafe { + std::env::set_var("SQLNESS_HOME", self.sqlness_home.display().to_string()); + } + match mode { + "standalone" => self.start_standalone(id).await, + "distributed" => self.start_distributed(id).await, + _ => panic!("Unexpected mode: {mode}"), + } + } + + /// Stop one [`Database`]. 
+ async fn stop(&self, _mode: &str, mut database: Self::DB) { + database.stop(); + } +} + +impl Env { + pub fn new( + data_home: PathBuf, + server_addrs: ServerAddr, + wal: WalConfig, + pull_version_on_need: bool, + bins_dir: Option, + store_config: StoreConfig, + extra_args: Vec, + ) -> Self { + Self { + sqlness_home: data_home, + server_addrs, + wal, + pull_version_on_need, + bins_dir: Arc::new(Mutex::new(bins_dir.clone())), + versioned_bins_dirs: Arc::new(Mutex::new(HashMap::from_iter([( + "latest".to_string(), + bins_dir.clone().unwrap_or(util::get_binary_dir("debug")), + )]))), + store_config, + extra_args, + } + } + + async fn start_standalone(&self, id: usize) -> GreptimeDB { + println!("Starting standalone instance id: {id}"); + + if self.server_addrs.server_addr.is_some() { + self.connect_db(&self.server_addrs, id).await + } else { + self.build_db(); + self.setup_wal(); + let mut db_ctx = GreptimeDBContext::new(self.wal.clone(), self.store_config.clone()); + + let server_mode = ServerMode::random_standalone(); + db_ctx.set_server_mode(server_mode.clone(), SERVER_MODE_STANDALONE_IDX); + let server_addr = server_mode.server_addr().unwrap(); + let server_process = self.start_server(server_mode, &db_ctx, id, true).await; + + let mut greptimedb = self.connect_db(&server_addr, id).await; + greptimedb.server_processes = Some(Arc::new(Mutex::new(vec![server_process]))); + greptimedb.is_standalone = true; + greptimedb.ctx = db_ctx; + + greptimedb + } + } + + async fn start_distributed(&self, id: usize) -> GreptimeDB { + if self.server_addrs.server_addr.is_some() { + self.connect_db(&self.server_addrs, id).await + } else { + self.build_db(); + self.setup_wal(); + self.setup_etcd(); + self.setup_pg(); + self.setup_mysql().await; + let mut db_ctx = GreptimeDBContext::new(self.wal.clone(), self.store_config.clone()); + + // start a distributed GreptimeDB + let meta_server_mode = ServerMode::random_metasrv(); + let metasrv_port = match &meta_server_mode { + ServerMode::Metasrv { + rpc_server_addr, .. + } => rpc_server_addr + .split(':') + .nth(1) + .unwrap() + .parse::() + .unwrap(), + _ => panic!( + "metasrv mode not set, maybe running in remote mode which doesn't support restart?" 
+                ),
+            };
+            db_ctx.set_server_mode(meta_server_mode.clone(), SERVER_MODE_METASRV_IDX);
+            let meta_server = self.start_server(meta_server_mode, &db_ctx, id, true).await;
+
+            let datanode_1_mode = ServerMode::random_datanode(metasrv_port, 0);
+            db_ctx.set_server_mode(datanode_1_mode.clone(), SERVER_MODE_DATANODE_START_IDX);
+            let datanode_1 = self.start_server(datanode_1_mode, &db_ctx, id, true).await;
+            let datanode_2_mode = ServerMode::random_datanode(metasrv_port, 1);
+            db_ctx.set_server_mode(datanode_2_mode.clone(), SERVER_MODE_DATANODE_START_IDX + 1);
+            let datanode_2 = self.start_server(datanode_2_mode, &db_ctx, id, true).await;
+            let datanode_3_mode = ServerMode::random_datanode(metasrv_port, 2);
+            db_ctx.set_server_mode(datanode_3_mode.clone(), SERVER_MODE_DATANODE_START_IDX + 2);
+            let datanode_3 = self.start_server(datanode_3_mode, &db_ctx, id, true).await;
+
+            let frontend_mode = ServerMode::random_frontend(metasrv_port);
+            let server_addr = frontend_mode.server_addr().unwrap();
+            db_ctx.set_server_mode(frontend_mode.clone(), SERVER_MODE_FRONTEND_IDX);
+            let frontend = self.start_server(frontend_mode, &db_ctx, id, true).await;
+
+            let flownode_mode = ServerMode::random_flownode(metasrv_port, 0);
+            db_ctx.set_server_mode(flownode_mode.clone(), SERVER_MODE_FLOWNODE_IDX);
+            let flownode = self.start_server(flownode_mode, &db_ctx, id, true).await;
+
+            let mut greptimedb = self.connect_db(&server_addr, id).await;
+
+            greptimedb.metasrv_process = Some(meta_server).into();
+            greptimedb.server_processes = Some(Arc::new(Mutex::new(vec![
+                datanode_1, datanode_2, datanode_3,
+            ])));
+            greptimedb.frontend_process = Some(frontend).into();
+            greptimedb.flownode_process = Some(flownode).into();
+            greptimedb.is_standalone = false;
+            greptimedb.ctx = db_ctx;
+
+            greptimedb
+        }
+    }
+
+    async fn connect_db(&self, server_addr: &ServerAddr, id: usize) -> GreptimeDB {
+        let grpc_server_addr = server_addr.server_addr.as_ref().unwrap();
+        let pg_server_addr = server_addr.pg_server_addr.as_ref().unwrap();
+        let mysql_server_addr = server_addr.mysql_server_addr.as_ref().unwrap();
+
+        let client =
+            MultiProtocolClient::connect(grpc_server_addr, pg_server_addr, mysql_server_addr).await;
+        GreptimeDB {
+            client: TokioMutex::new(client),
+            server_processes: None,
+            metasrv_process: None.into(),
+            frontend_process: None.into(),
+            flownode_process: None.into(),
+            ctx: GreptimeDBContext {
+                time: 0,
+                datanode_id: Default::default(),
+                wal: self.wal.clone(),
+                store_config: self.store_config.clone(),
+                server_modes: Vec::new(),
+            },
+            is_standalone: false,
+            env: self.clone(),
+            id,
+        }
+    }
+
+    fn stop_server(process: &mut Child) {
+        let _ = process.kill();
+        let _ = process.wait();
+    }
+
+    async fn start_server(
+        &self,
+        mode: ServerMode,
+        db_ctx: &GreptimeDBContext,
+        id: usize,
+        truncate_log: bool,
+    ) -> Child {
+        let log_file_name = match mode {
+            ServerMode::Datanode { node_id, .. } => {
+                db_ctx.incr_datanode_id();
+                format!("greptime-{}-sqlness-datanode-{}.log", id, node_id)
+            }
+            ServerMode::Flownode { .. } => format!("greptime-{}-sqlness-flownode.log", id),
+            ServerMode::Frontend { .. } => format!("greptime-{}-sqlness-frontend.log", id),
+            ServerMode::Metasrv { .. } => format!("greptime-{}-sqlness-metasrv.log", id),
+            ServerMode::Standalone { .. } => format!("greptime-{}-sqlness-standalone.log", id),
+        };
+        let stdout_file_name = self.sqlness_home.join(log_file_name).display().to_string();
+
+        println!("DB instance {id} log file at {stdout_file_name}");
+
+        let stdout_file = OpenOptions::new()
+            .create(true)
+            .write(true)
+            .truncate(truncate_log)
+            .append(!truncate_log)
+            .open(stdout_file_name)
+            .unwrap();
+
+        let args = mode.get_args(&self.sqlness_home, self, db_ctx, id);
+        let check_ip_addrs = mode.check_addrs();
+
+        for check_ip_addr in &check_ip_addrs {
+            if util::check_port(check_ip_addr.parse().unwrap(), Duration::from_secs(1)).await {
+                panic!("Port {check_ip_addr} is already in use, please check and retry.");
+            }
+        }
+
+        let program = PROGRAM;
+
+        let bins_dir = self.bins_dir.lock().unwrap().clone().expect(
+            "GreptimeDB binary is not available. Please pass in the path to the directory that contains the pre-built GreptimeDB binary. Or you may call `self.build_db()` beforehand.",
+        );
+
+        let abs_bins_dir = bins_dir
+            .canonicalize()
+            .expect("Failed to canonicalize bins_dir");
+
+        let mut process = Command::new(abs_bins_dir.join(program))
+            .current_dir(bins_dir.clone())
+            .env("TZ", "UTC")
+            .args(args)
+            .stdout(stdout_file)
+            .spawn()
+            .unwrap_or_else(|error| {
+                panic!(
+                    "Failed to start the DB with subcommand {}, Error: {error}, path: {:?}",
+                    mode.name(),
+                    bins_dir.join(program)
+                );
+            });
+
+        for check_ip_addr in &check_ip_addrs {
+            if !util::check_port(check_ip_addr.parse().unwrap(), Duration::from_secs(10)).await {
+                Env::stop_server(&mut process);
+                panic!("{} didn't start up in 10 seconds, quit.", mode.name())
+            }
+        }
+
+        process
+    }
+
+    /// Stop and restart the server processes.
+    async fn restart_server(&self, db: &GreptimeDB, is_full_restart: bool) {
+        {
+            if let Some(server_process) = db.server_processes.clone() {
+                let mut server_processes = server_process.lock().unwrap();
+                for server_process in server_processes.iter_mut() {
+                    Env::stop_server(server_process);
+                }
+            }
+
+            if is_full_restart {
+                if let Some(mut metasrv_process) =
+                    db.metasrv_process.lock().expect("poisoned lock").take()
+                {
+                    Env::stop_server(&mut metasrv_process);
+                }
+                if let Some(mut frontend_process) =
+                    db.frontend_process.lock().expect("poisoned lock").take()
+                {
+                    Env::stop_server(&mut frontend_process);
+                }
+            }
+
+            if let Some(mut flownode_process) =
+                db.flownode_process.lock().expect("poisoned lock").take()
+            {
+                Env::stop_server(&mut flownode_process);
+            }
+        }
+
+        // check if the server is distributed or standalone
+        let new_server_processes = if db.is_standalone {
+            let server_mode = db
+                .ctx
+                .get_server_mode(SERVER_MODE_STANDALONE_IDX)
+                .cloned()
+                .unwrap();
+            let server_addr = server_mode.server_addr().unwrap();
+            let new_server_process = self.start_server(server_mode, &db.ctx, db.id, false).await;
+
+            let mut client = db.client.lock().await;
+            client
+                .reconnect_mysql_client(&server_addr.mysql_server_addr.unwrap())
+                .await;
+            client
+                .reconnect_pg_client(&server_addr.pg_server_addr.unwrap())
+                .await;
+            vec![new_server_process]
+        } else {
+            db.ctx.reset_datanode_id();
+            if is_full_restart {
+                let metasrv_mode = db
+                    .ctx
+                    .get_server_mode(SERVER_MODE_METASRV_IDX)
+                    .cloned()
+                    .unwrap();
+                let metasrv = self.start_server(metasrv_mode, &db.ctx, db.id, false).await;
+                db.metasrv_process
+                    .lock()
+                    .expect("lock poisoned")
+                    .replace(metasrv);
+
+                // wait for metasrv to start, since older versions of the db
+                // might take longer to complete the election
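+                // (a fixed 5-second pause; assumed long enough in practice)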
+                tokio::time::sleep(Duration::from_secs(5)).await;
+            }
+
+            let mut processes = vec![];
+            for i in 0..3 {
+                let datanode_mode = db
+                    .ctx
+                    .get_server_mode(SERVER_MODE_DATANODE_START_IDX + i)
+                    .cloned()
+                    .unwrap();
+                let new_server_process = self
+                    .start_server(datanode_mode, &db.ctx, db.id, false)
+                    .await;
+                processes.push(new_server_process);
+            }
+
+            if is_full_restart {
+                let frontend_mode = db
+                    .ctx
+                    .get_server_mode(SERVER_MODE_FRONTEND_IDX)
+                    .cloned()
+                    .unwrap();
+                let frontend = self
+                    .start_server(frontend_mode, &db.ctx, db.id, false)
+                    .await;
+                db.frontend_process
+                    .lock()
+                    .expect("lock poisoned")
+                    .replace(frontend);
+            }
+
+            let flownode_mode = db
+                .ctx
+                .get_server_mode(SERVER_MODE_FLOWNODE_IDX)
+                .cloned()
+                .unwrap();
+            let flownode = self
+                .start_server(flownode_mode, &db.ctx, db.id, false)
+                .await;
+            db.flownode_process
+                .lock()
+                .expect("lock poisoned")
+                .replace(flownode);
+
+            processes
+        };
+
+        if let Some(server_processes) = db.server_processes.clone() {
+            let mut server_processes = server_processes.lock().unwrap();
+            *server_processes = new_server_processes;
+        }
+    }
+
+    /// Setup kafka wal cluster if needed. The counterpart is in [GreptimeDB::stop].
+    fn setup_wal(&self) {
+        if matches!(self.wal, WalConfig::Kafka { needs_kafka_cluster, .. } if needs_kafka_cluster) {
+            util::setup_wal();
+        }
+    }
+
+    /// Setup etcd if needed.
+    fn setup_etcd(&self) {
+        if self.store_config.setup_etcd {
+            let client_ports = self
+                .store_config
+                .store_addrs
+                .iter()
+                .map(|s| s.split(':').nth(1).unwrap().parse::<u16>().unwrap())
+                .collect::<Vec<_>>();
+            util::setup_etcd(client_ports, None, None);
+        }
+    }
+
+    /// Setup PostgreSQL if needed.
+    fn setup_pg(&self) {
+        if matches!(self.store_config.setup_pg, Some(ServiceProvider::Create)) {
+            let client_ports = self
+                .store_config
+                .store_addrs
+                .iter()
+                .map(|s| s.split(':').nth(1).unwrap().parse::<u16>().unwrap())
+                .collect::<Vec<_>>();
+            let client_port = client_ports.first().unwrap_or(&5432);
+            util::setup_pg(*client_port, None);
+        }
+    }
+
+    /// Setup MySQL if needed.
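+    ///
+    /// Only runs when the store config asks the runner to create the service
+    /// itself (`ServiceProvider::Create`); external addresses are left alone.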
+    async fn setup_mysql(&self) {
+        if matches!(self.store_config.setup_mysql, Some(ServiceProvider::Create)) {
+            let client_ports = self
+                .store_config
+                .store_addrs
+                .iter()
+                .map(|s| s.split(':').nth(1).unwrap().parse::<u16>().unwrap())
+                .collect::<Vec<_>>();
+            let client_port = client_ports.first().unwrap_or(&3306);
+            util::setup_mysql(*client_port, None);
+
+            // The MySQL Docker container starts slowly, so wait for a while
+            tokio::time::sleep(Duration::from_secs(10)).await;
+        }
+    }
+
+    /// Build the DB with `cargo build --bin greptime`
+    fn build_db(&self) {
+        if self.bins_dir.lock().unwrap().is_some() {
+            return;
+        }
+
+        println!("Going to build the DB...");
+        let output = Command::new("cargo")
+            .current_dir(util::get_workspace_root())
+            .args([
+                "build",
+                "--bin",
+                "greptime",
+                "--features",
+                "pg_kvbackend,mysql_kvbackend",
+            ])
+            .output()
+            .expect("Failed to execute `cargo build`");
+        if !output.status.success() {
+            println!("Failed to build GreptimeDB, {}", output.status);
+            println!("Cargo build stdout:");
+            io::stdout().write_all(&output.stdout).unwrap();
+            println!("Cargo build stderr:");
+            io::stderr().write_all(&output.stderr).unwrap();
+            panic!();
+        }
+
+        let _ = self
+            .bins_dir
+            .lock()
+            .unwrap()
+            .insert(util::get_binary_dir("debug"));
+    }
+
+    pub(crate) fn extra_args(&self) -> &Vec<String> {
+        &self.extra_args
+    }
+}
+
+pub struct GreptimeDB {
+    server_processes: Option<Arc<Mutex<Vec<Child>>>>,
+    metasrv_process: Mutex<Option<Child>>,
+    frontend_process: Mutex<Option<Child>>,
+    flownode_process: Mutex<Option<Child>>,
+    client: TokioMutex<MultiProtocolClient>,
+    ctx: GreptimeDBContext,
+    is_standalone: bool,
+    env: Env,
+    id: usize,
+}
+
+impl GreptimeDB {
+    async fn postgres_query(&self, _ctx: QueryContext, query: String) -> Box<dyn Display> {
+        let mut client = self.client.lock().await;
+
+        match client.postgres_query(&query).await {
+            Ok(rows) => Box::new(PostgresqlFormatter::from(rows)),
+            Err(e) => Box::new(e),
+        }
+    }
+
+    async fn mysql_query(&self, _ctx: QueryContext, query: String) -> Box<dyn Display> {
+        let mut client = self.client.lock().await;
+
+        match client.mysql_query(&query).await {
+            Ok(res) => Box::new(MysqlFormatter::from(res)),
+            Err(e) => Box::new(e),
+        }
+    }
+
+    async fn grpc_query(&self, _ctx: QueryContext, query: String) -> Box<dyn Display> {
+        let mut client = self.client.lock().await;
+
+        match client.grpc_query(&query).await {
+            Ok(rows) => Box::new(OutputFormatter::from(rows)),
+            Err(e) => Box::new(ErrorFormatter::from(e)),
+        }
+    }
+}
+
+#[async_trait]
+impl Database for GreptimeDB {
+    async fn query(&self, ctx: QueryContext, query: String) -> Box<dyn Display> {
+        if ctx.context.contains_key("restart") && self.env.server_addrs.server_addr.is_none() {
+            self.env.restart_server(self, false).await;
+        } else if let Some(version) = ctx.context.get("version") {
+            let version_bin_dir = self
+                .env
+                .versioned_bins_dirs
+                .lock()
+                .expect("lock poisoned")
+                .get(version.as_str())
+                .cloned();
+
+            match version_bin_dir {
+                Some(path) if path.clone().join(PROGRAM).is_file() => {
+                    // use version in versioned_bins_dirs
+                    *self.env.bins_dir.lock().unwrap() = Some(path.clone());
+                }
+                _ => {
+                    // use version in dir files
+                    maybe_pull_binary(version, self.env.pull_version_on_need).await;
+                    let root = get_workspace_root();
+                    let new_path = PathBuf::from_iter([&root, version]);
+                    *self.env.bins_dir.lock().unwrap() = Some(new_path);
+                }
+            }
+
+            self.env.restart_server(self, true).await;
+            // sleep for a while to wait for the server to fully boot up
+            tokio::time::sleep(Duration::from_secs(5)).await;
+        }
+
+        if let Some(protocol) = ctx.context.get(PROTOCOL_KEY) {
+            // protocol is bound to be either "mysql" or "postgres"
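+            // (PROTOCOL_KEY is set by the protocol interceptor when a case
+            // opts into a specific protocol)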
"mysql" or "postgres" + if protocol == MYSQL { + self.mysql_query(ctx, query).await + } else { + self.postgres_query(ctx, query).await + } + } else { + self.grpc_query(ctx, query).await + } + } +} + +impl GreptimeDB { + fn stop(&mut self) { + if let Some(server_processes) = self.server_processes.clone() { + let mut server_processes = server_processes.lock().unwrap(); + for mut server_process in server_processes.drain(..) { + Env::stop_server(&mut server_process); + println!( + "Standalone or Datanode (pid = {}) is stopped", + server_process.id() + ); + } + } + if let Some(mut metasrv) = self + .metasrv_process + .lock() + .expect("someone else panic when holding lock") + .take() + { + Env::stop_server(&mut metasrv); + println!("Metasrv (pid = {}) is stopped", metasrv.id()); + } + if let Some(mut frontend) = self + .frontend_process + .lock() + .expect("someone else panic when holding lock") + .take() + { + Env::stop_server(&mut frontend); + println!("Frontend (pid = {}) is stopped", frontend.id()); + } + if let Some(mut flownode) = self + .flownode_process + .lock() + .expect("someone else panic when holding lock") + .take() + { + Env::stop_server(&mut flownode); + println!("Flownode (pid = {}) is stopped", flownode.id()); + } + if matches!(self.ctx.wal, WalConfig::Kafka { needs_kafka_cluster, .. } if needs_kafka_cluster) + { + util::teardown_wal(); + } + } +} + +impl Drop for GreptimeDB { + fn drop(&mut self) { + if self.env.server_addrs.server_addr.is_none() { + self.stop(); + } + } +} + +pub struct GreptimeDBContext { + /// Start time in millisecond + time: i64, + datanode_id: AtomicU32, + wal: WalConfig, + store_config: StoreConfig, + server_modes: Vec, +} + +impl GreptimeDBContext { + pub fn new(wal: WalConfig, store_config: StoreConfig) -> Self { + Self { + time: common_time::util::current_time_millis(), + datanode_id: AtomicU32::new(0), + wal, + store_config, + server_modes: Vec::new(), + } + } + + pub(crate) fn time(&self) -> i64 { + self.time + } + + pub fn is_raft_engine(&self) -> bool { + matches!(self.wal, WalConfig::RaftEngine) + } + + pub fn kafka_wal_broker_endpoints(&self) -> String { + match &self.wal { + WalConfig::RaftEngine => String::new(), + WalConfig::Kafka { + broker_endpoints, .. + } => serde_json::to_string(&broker_endpoints).unwrap(), + } + } + + fn incr_datanode_id(&self) { + let _ = self.datanode_id.fetch_add(1, Ordering::Relaxed); + } + + fn reset_datanode_id(&self) { + self.datanode_id.store(0, Ordering::Relaxed); + } + + pub(crate) fn store_config(&self) -> StoreConfig { + self.store_config.clone() + } + + fn set_server_mode(&mut self, mode: ServerMode, idx: usize) { + if idx >= self.server_modes.len() { + self.server_modes.resize(idx + 1, mode.clone()); + } + self.server_modes[idx] = mode; + } + + fn get_server_mode(&self, idx: usize) -> Option<&ServerMode> { + self.server_modes.get(idx) + } +} diff --git a/tests/runner/src/env/kube.rs b/tests/runner/src/env/kube.rs new file mode 100644 index 0000000000..91fb747102 --- /dev/null +++ b/tests/runner/src/env/kube.rs @@ -0,0 +1,204 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::fmt::Display;
+use std::path::Path;
+use std::sync::Arc;
+
+use async_trait::async_trait;
+use sqlness::{Database, EnvController, QueryContext};
+use tokio::process::Command;
+use tokio::sync::Mutex;
+
+use crate::client::MultiProtocolClient;
+use crate::formatter::{ErrorFormatter, MysqlFormatter, OutputFormatter, PostgresqlFormatter};
+use crate::protocol_interceptor::{MYSQL, PROTOCOL_KEY};
+
+#[async_trait]
+pub trait DatabaseManager: Send + Sync {
+    // Restarts the database.
+    async fn restart(&self, database: &GreptimeDB);
+}
+
+#[async_trait]
+impl DatabaseManager for () {
+    async fn restart(&self, _: &GreptimeDB) {
+        // Do nothing
+    }
+}
+
+#[async_trait]
+pub trait ResourcesManager: Send + Sync {
+    // Delete namespace.
+    async fn delete_namespace(&self);
+
+    // Get namespace.
+    fn namespace(&self) -> &str;
+}
+
+#[derive(Clone)]
+pub struct Env {
+    /// Whether to delete the namespace on stop.
+    pub delete_namespace_on_stop: bool,
+    /// Address of the grpc server.
+    pub server_addr: String,
+    /// Address of the postgres server.
+    pub pg_server_addr: String,
+    /// Address of the mysql server.
+    pub mysql_server_addr: String,
+    /// The database manager.
+    pub database_manager: Arc<dyn DatabaseManager>,
+    /// The resources manager.
+    pub resources_manager: Arc<dyn ResourcesManager>,
+}
+
+#[async_trait]
+impl EnvController for Env {
+    type DB = GreptimeDB;
+
+    async fn start(&self, mode: &str, id: usize, _config: Option<&Path>) -> Self::DB {
+        if id > 0 {
+            panic!("Parallel test mode is not supported in kube env");
+        }
+
+        match mode {
+            "standalone" | "distributed" => GreptimeDB {
+                client: Mutex::new(
+                    MultiProtocolClient::connect(
+                        &self.server_addr,
+                        &self.pg_server_addr,
+                        &self.mysql_server_addr,
+                    )
+                    .await,
+                ),
+                database_manager: self.database_manager.clone(),
+                resources_manager: self.resources_manager.clone(),
+                delete_namespace_on_stop: self.delete_namespace_on_stop,
+            },
+
+            _ => panic!("Unexpected mode: {mode}"),
+        }
+    }
+
+    async fn stop(&self, _mode: &str, database: Self::DB) {
+        database.stop().await;
+    }
+}
+
+pub struct GreptimeDB {
+    pub client: Mutex<MultiProtocolClient>,
+    pub delete_namespace_on_stop: bool,
+    pub database_manager: Arc<dyn DatabaseManager>,
+    pub resources_manager: Arc<dyn ResourcesManager>,
+}
+
+impl GreptimeDB {
+    async fn postgres_query(&self, _ctx: QueryContext, query: String) -> Box<dyn Display> {
+        let mut client = self.client.lock().await;
+
+        match client.postgres_query(&query).await {
+            Ok(rows) => Box::new(PostgresqlFormatter::from(rows)),
+            Err(e) => Box::new(e),
+        }
+    }
+
+    async fn mysql_query(&self, _ctx: QueryContext, query: String) -> Box<dyn Display> {
+        let mut client = self.client.lock().await;
+
+        match client.mysql_query(&query).await {
+            Ok(res) => Box::new(MysqlFormatter::from(res)),
+            Err(e) => Box::new(e),
+        }
+    }
+
+    async fn grpc_query(&self, _ctx: QueryContext, query: String) -> Box<dyn Display> {
+        let mut client = self.client.lock().await;
+
+        match client.grpc_query(&query).await {
+            Ok(rows) => Box::new(OutputFormatter::from(rows)),
+            Err(e) => Box::new(ErrorFormatter::from(e)),
+        }
+    }
+}
+
+#[async_trait]
+impl Database for GreptimeDB {
+    async fn query(&self, ctx: QueryContext, query: String) -> Box<dyn Display> {
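+        // A `restart` key in the case context asks the database manager to
+        // restart the db before running the query.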
+        if ctx.context.contains_key("restart") {
+            self.database_manager.restart(self).await
+        }
+
+        if let Some(protocol) = ctx.context.get(PROTOCOL_KEY) {
+            // protocol is bound to be either "mysql" or "postgres"
+            if protocol == MYSQL {
+                self.mysql_query(ctx, query).await
+            } else {
+                self.postgres_query(ctx, query).await
+            }
+        } else {
+            self.grpc_query(ctx, query).await
+        }
+    }
+}
+
+impl GreptimeDB {
+    async fn stop(&self) {
+        if self.delete_namespace_on_stop {
+            self.resources_manager.delete_namespace().await;
+            println!("Deleted namespace({})", self.resources_manager.namespace());
+        } else {
+            println!(
+                "Namespace({}) is not deleted",
+                self.resources_manager.namespace()
+            );
+        }
+    }
+}
+
+pub struct NaiveResourcesManager {
+    namespace: String,
+}
+
+impl NaiveResourcesManager {
+    pub fn new(namespace: String) -> Self {
+        Self { namespace }
+    }
+}
+
+#[async_trait]
+impl ResourcesManager for NaiveResourcesManager {
+    async fn delete_namespace(&self) {
+        let output = Command::new("kubectl")
+            .arg("delete")
+            .arg("namespace")
+            .arg(&self.namespace)
+            .output()
+            .await
+            .unwrap_or_else(|e| {
+                panic!(
+                    "Failed to execute kubectl delete namespace({}): {}",
+                    self.namespace, e
+                )
+            });
+
+        if !output.status.success() {
+            let stderr = String::from_utf8_lossy(&output.stderr);
+            panic!("Failed to delete namespace({}): {}", self.namespace, stderr);
+        }
+    }
+
+    fn namespace(&self) -> &str {
+        &self.namespace
+    }
+}
diff --git a/tests/runner/src/formatter.rs b/tests/runner/src/formatter.rs
new file mode 100644
index 0000000000..a379ad3244
--- /dev/null
+++ b/tests/runner/src/formatter.rs
@@ -0,0 +1,216 @@
+// Copyright 2023 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::borrow::Cow;
+use std::fmt::Display;
+use std::sync::Arc;
+
+use client::{Output, OutputData, RecordBatches};
+use common_error::ext::ErrorExt;
+use datatypes::prelude::ConcreteDataType;
+use datatypes::scalars::ScalarVectorBuilder;
+use datatypes::schema::{ColumnSchema, Schema};
+use datatypes::vectors::{StringVectorBuilder, VectorRef};
+use mysql::Row as MySqlRow;
+use tokio_postgres::SimpleQueryMessage as PgRow;
+
+use crate::client::MysqlSqlResult;
+
+/// A formatter for errors.
+pub struct ErrorFormatter<E: ErrorExt>(E);
+
+impl<E: ErrorExt> From<E> for ErrorFormatter<E> {
+    fn from(error: E) -> Self {
+        ErrorFormatter(error)
+    }
+}
+
+impl<E: ErrorExt> Display for ErrorFormatter<E> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        let status_code = self.0.status_code();
+        let root_cause = self.0.output_msg();
+        write!(
+            f,
+            "Error: {}({status_code}), {root_cause}",
+            status_code as u32
+        )
+    }
+}
+
+/// A formatter for [`Output`].
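+///
+/// `AffectedRows` prints as a one-line summary; record batches are rendered
+/// with their pretty-printed table form.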
+pub struct OutputFormatter(Output);
+
+impl From<Output> for OutputFormatter {
+    fn from(output: Output) -> Self {
+        OutputFormatter(output)
+    }
+}
+
+impl Display for OutputFormatter {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match &self.0.data {
+            OutputData::AffectedRows(rows) => {
+                write!(f, "Affected Rows: {rows}")
+            }
+            OutputData::RecordBatches(recordbatches) => {
+                let pretty = recordbatches.pretty_print().map_err(|e| e.to_string());
+                match pretty {
+                    Ok(s) => write!(f, "{s}"),
+                    Err(e) => {
+                        write!(f, "Failed to pretty format {recordbatches:?}, error: {e}")
+                    }
+                }
+            }
+            OutputData::Stream(_) => unreachable!(),
+        }
+    }
+}
+
+pub struct PostgresqlFormatter(Vec<PgRow>);
+
+impl From<Vec<PgRow>> for PostgresqlFormatter {
+    fn from(rows: Vec<PgRow>) -> Self {
+        PostgresqlFormatter(rows)
+    }
+}
+
+impl Display for PostgresqlFormatter {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        if self.0.is_empty() {
+            return f.write_fmt(format_args!("(Empty response)"));
+        }
+
+        if let PgRow::CommandComplete(affected_rows) = &self.0[0] {
+            return write!(
+                f,
+                "{}",
+                OutputFormatter(Output::new_with_affected_rows(*affected_rows as usize))
+            );
+        };
+
+        let Some(recordbatches) = build_recordbatches_from_postgres_rows(&self.0) else {
+            return Ok(());
+        };
+        write!(
+            f,
+            "{}",
+            OutputFormatter(Output::new_with_record_batches(recordbatches))
+        )
+    }
+}
+
+fn build_recordbatches_from_postgres_rows(rows: &[PgRow]) -> Option<RecordBatches> {
+    // create schema
+    let schema = match &rows[0] {
+        PgRow::RowDescription(desc) => Arc::new(Schema::new(
+            desc.iter()
+                .map(|column| {
+                    ColumnSchema::new(column.name(), ConcreteDataType::string_datatype(), true)
+                })
+                .collect(),
+        )),
+        _ => unreachable!(),
+    };
+    if schema.num_columns() == 0 {
+        return None;
+    }
+
+    // convert to string vectors
+    let mut columns: Vec<StringVectorBuilder> = (0..schema.num_columns())
+        .map(|_| StringVectorBuilder::with_capacity(schema.num_columns()))
+        .collect();
+    for row in rows.iter().skip(1) {
+        if let PgRow::Row(row) = row {
+            for (i, column) in columns.iter_mut().enumerate().take(schema.num_columns()) {
+                column.push(row.get(i));
+            }
+        }
+    }
+    let columns: Vec<VectorRef> = columns
+        .into_iter()
+        .map(|mut col| Arc::new(col.finish()) as VectorRef)
+        .collect();
+
+    // construct recordbatch
+    let recordbatches = RecordBatches::try_from_columns(schema, columns)
+        .expect("Failed to construct recordbatches from columns. Please check the schema.");
+    Some(recordbatches)
+}
+
+/// A formatter for [`MysqlSqlResult`].
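+///
+/// Affected-row counts print as `affected_rows: N`; result rows are rebuilt
+/// into string record batches and pretty-printed like the other protocols.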
+pub struct MysqlFormatter(MysqlSqlResult);
+
+impl From<MysqlSqlResult> for MysqlFormatter {
+    fn from(result: MysqlSqlResult) -> Self {
+        MysqlFormatter(result)
+    }
+}
+
+impl Display for MysqlFormatter {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match &self.0 {
+            MysqlSqlResult::AffectedRows(rows) => {
+                write!(f, "affected_rows: {rows}")
+            }
+            MysqlSqlResult::Rows(rows) => {
+                if rows.is_empty() {
+                    return f.write_fmt(format_args!("(Empty response)"));
+                }
+
+                let recordbatches = build_recordbatches_from_mysql_rows(rows);
+                write!(
+                    f,
+                    "{}",
+                    OutputFormatter(Output::new_with_record_batches(recordbatches))
+                )
+            }
+        }
+    }
+}
+
+pub fn build_recordbatches_from_mysql_rows(rows: &[MySqlRow]) -> RecordBatches {
+    // create schema
+    let head_column = &rows[0];
+    let head_binding = head_column.columns();
+    let names = head_binding
+        .iter()
+        .map(|column| column.name_str())
+        .collect::<Vec<Cow<'_, str>>>();
+    let schema = Arc::new(Schema::new(
+        names
+            .iter()
+            .map(|name| {
+                ColumnSchema::new(name.to_string(), ConcreteDataType::string_datatype(), false)
+            })
+            .collect(),
+    ));
+
+    // convert to string vectors
+    let mut columns: Vec<StringVectorBuilder> = (0..schema.num_columns())
+        .map(|_| StringVectorBuilder::with_capacity(schema.num_columns()))
+        .collect();
+    for row in rows.iter() {
+        for (i, name) in names.iter().enumerate() {
+            columns[i].push(row.get::<String, &str>(name).as_deref());
+        }
+    }
+    let columns: Vec<VectorRef> = columns
+        .into_iter()
+        .map(|mut col| Arc::new(col.finish()) as VectorRef)
+        .collect();
+
+    // construct recordbatch
+    RecordBatches::try_from_columns(schema, columns)
+        .expect("Failed to construct recordbatches from columns. Please check the schema.")
+}
diff --git a/tests/runner/src/main.rs b/tests/runner/src/main.rs
index c2392b3599..c6cf5dd4b5 100644
--- a/tests/runner/src/main.rs
+++ b/tests/runner/src/main.rs
@@ -14,214 +14,24 @@
 
 #![allow(clippy::print_stdout)]
 
-use std::path::PathBuf;
-use std::sync::Arc;
+use clap::Parser;
 
-use clap::{Parser, ValueEnum};
-use env::{Env, WalConfig};
-use sqlness::interceptor::Registry;
-use sqlness::{ConfigBuilder, Runner};
-
-use crate::env::{ServiceProvider, StoreConfig};
+use crate::cmd::{Command, SubCommand};
 
+pub mod client;
+mod cmd;
 mod env;
-mod protocol_interceptor;
+pub mod formatter;
+pub mod protocol_interceptor;
 mod server_mode;
 mod util;
 
-#[derive(ValueEnum, Debug, Clone)]
-#[clap(rename_all = "snake_case")]
-enum Wal {
-    RaftEngine,
-    Kafka,
-}
-
-// add a group to ensure that all server addresses are set together
-#[derive(clap::Args, Debug, Clone, Default)]
-#[group(multiple = true, requires_all=["server_addr", "pg_server_addr", "mysql_server_addr"])]
-struct ServerAddr {
-    /// Address of the grpc server.
-    #[clap(short, long)]
-    server_addr: Option<String>,
-
-    /// Address of the postgres server. Must be set if server_addr is set.
-    #[clap(short, long, requires = "server_addr")]
-    pg_server_addr: Option<String>,
-
-    /// Address of the mysql server. Must be set if server_addr is set.
-    #[clap(short, long, requires = "server_addr")]
-    mysql_server_addr: Option<String>,
-}
-
-#[derive(Parser, Debug)]
-#[clap(author, version, about, long_about = None)]
-/// SQL Harness for GrepTimeDB
-struct Args {
-    /// Directory of test cases
-    #[clap(short, long)]
-    case_dir: Option<PathBuf>,
-
-    /// Fail this run as soon as one case fails if true
-    #[arg(short, long, default_value = "false")]
-    fail_fast: bool,
-
-    /// Environment Configuration File
-    #[clap(short, long, default_value = "config.toml")]
-    env_config_file: String,
-
-    /// Name of test cases to run. Accept as a regexp.
-    #[clap(short, long, default_value = ".*")]
-    test_filter: String,
-
-    /// Addresses of the server.
-    #[command(flatten)]
-    server_addr: ServerAddr,
-
-    /// The type of Wal.
-    #[clap(short, long, default_value = "raft_engine")]
-    wal: Wal,
-
-    /// The kafka wal broker endpoints. This config will suppress sqlness runner
-    /// from starting a kafka cluster, and use the given endpoint as kafka backend.
-    #[clap(short, long)]
-    kafka_wal_broker_endpoints: Option<String>,
-
-    /// The path to the directory where GreptimeDB's binaries resides.
-    /// If not set, sqlness will build GreptimeDB on the fly.
-    #[clap(long)]
-    bins_dir: Option<PathBuf>,
-
-    /// Preserve persistent state in the temporary directory.
-    /// This may affect future test runs.
-    #[clap(long)]
-    preserve_state: bool,
-
-    /// Pull Different versions of GreptimeDB on need.
-    #[clap(long, default_value = "true")]
-    pull_version_on_need: bool,
-
-    /// The store addresses for metadata, if empty, will use memory store.
-    #[clap(long)]
-    store_addrs: Vec<String>,
-
-    /// Whether to setup etcd, by default it is false.
-    #[clap(long, default_value = "false")]
-    setup_etcd: bool,
-
-    /// Whether to setup pg, by default it is false.
-    #[clap(long, default_missing_value = "", num_args(0..=1))]
-    setup_pg: Option<ServiceProvider>,
-
-    /// Whether to setup mysql, by default it is false.
-    #[clap(long, default_missing_value = "", num_args(0..=1))]
-    setup_mysql: Option<ServiceProvider>,
-
-    /// The number of jobs to run in parallel. Default to half of the cores.
-    #[clap(short, long, default_value = "0")]
-    jobs: usize,
-
-    /// Extra command line arguments when starting GreptimeDB binaries.
-    #[clap(long)]
-    extra_args: Vec<String>,
-}
-
 #[tokio::main]
 async fn main() {
-    let mut args = Args::parse();
+    let cmd = Command::parse();
 
-    let temp_dir = tempfile::Builder::new()
-        .prefix("sqlness")
-        .tempdir()
-        .unwrap();
-    let sqlness_home = temp_dir.keep();
-
-    let mut interceptor_registry: Registry = Default::default();
-    interceptor_registry.register(
-        protocol_interceptor::PREFIX,
-        Arc::new(protocol_interceptor::ProtocolInterceptorFactory),
-    );
-
-    if let Some(d) = &args.case_dir
-        && !d.is_dir()
-    {
-        panic!("{} is not a directory", d.display());
-    }
-    if args.jobs == 0 {
-        args.jobs = num_cpus::get() / 2;
-    }
-
-    // normalize parallelism to 1 if any of the following conditions are met:
-    // Note: parallelism in pg and mysql is possible, but need configuration.
-    if args.server_addr.server_addr.is_some()
-        || args.setup_etcd
-        || args.setup_pg.is_some()
-        || args.setup_mysql.is_some()
-        || args.kafka_wal_broker_endpoints.is_some()
-        || args.test_filter != ".*"
-    {
-        args.jobs = 1;
-        println!(
-            "Normalizing parallelism to 1 due to server addresses, etcd/pg/mysql setup, or test filter usage"
-        );
-    }
-
-    let config = ConfigBuilder::default()
-        .case_dir(util::get_case_dir(args.case_dir))
-        .fail_fast(args.fail_fast)
-        .test_filter(args.test_filter)
-        .follow_links(true)
-        .env_config_file(args.env_config_file)
-        .interceptor_registry(interceptor_registry)
-        .parallelism(args.jobs)
-        .build()
-        .unwrap();
-
-    let wal = match args.wal {
-        Wal::RaftEngine => WalConfig::RaftEngine,
-        Wal::Kafka => WalConfig::Kafka {
-            needs_kafka_cluster: args.kafka_wal_broker_endpoints.is_none(),
-            broker_endpoints: args
-                .kafka_wal_broker_endpoints
-                .map(|s| s.split(',').map(|s| s.to_string()).collect())
-                // otherwise default to the same port in `kafka-cluster.yml`
-                .unwrap_or(vec!["127.0.0.1:9092".to_string()]),
-        },
-    };
-
-    let store = StoreConfig {
-        store_addrs: args.store_addrs.clone(),
-        setup_etcd: args.setup_etcd,
-        setup_pg: args.setup_pg,
-        setup_mysql: args.setup_mysql,
-    };
-
-    let runner = Runner::new(
-        config,
-        Env::new(
-            sqlness_home.clone(),
-            args.server_addr.clone(),
-            wal,
-            args.pull_version_on_need,
-            args.bins_dir,
-            store,
-            args.extra_args,
-        ),
-    );
-    match runner.run().await {
-        Ok(_) => println!("\x1b[32mAll sqlness tests passed!\x1b[0m"),
-        Err(e) => {
-            println!("\x1b[31mTest failed: {}\x1b[0m", e);
-            std::process::exit(1);
-        }
-    }
-
-    // clean up and exit
-    if !args.preserve_state {
-        if args.setup_etcd {
-            println!("Stopping etcd");
-            util::stop_rm_etcd();
-        }
-        println!("Removing state in {:?}", sqlness_home);
-        tokio::fs::remove_dir_all(sqlness_home).await.unwrap();
+    match cmd.subcmd {
+        SubCommand::Bare(cmd) => cmd.run().await,
+        SubCommand::Kube(cmd) => cmd.run().await,
     }
 }
diff --git a/tests/runner/src/server_mode.rs b/tests/runner/src/server_mode.rs
index 77f9154c36..2e4cefa2c6 100644
--- a/tests/runner/src/server_mode.rs
+++ b/tests/runner/src/server_mode.rs
@@ -19,8 +19,9 @@ use std::sync::{Mutex, OnceLock};
 use serde::Serialize;
 use tinytemplate::TinyTemplate;
 
-use crate::env::{Env, GreptimeDBContext, ServiceProvider};
-use crate::{ServerAddr, util};
+use crate::cmd::bare::ServerAddr;
+use crate::env::bare::{Env, GreptimeDBContext, ServiceProvider};
+use crate::util;
 
 const DEFAULT_LOG_LEVEL: &str = "--log-level=debug,hyper=warn,tower=warn,datafusion=warn,reqwest=warn,sqlparser=warn,h2=info,opendal=info";
 
diff --git a/tests/runner/src/util.rs b/tests/runner/src/util.rs
index 621b5856c3..6c00c8a9a6 100644
--- a/tests/runner/src/util.rs
+++ b/tests/runner/src/util.rs
@@ -533,3 +533,28 @@ pub fn get_random_port() -> u16 {
         .expect("Failed to get local address")
         .port()
 }
+
+/// Retries the function until it succeeds or the maximum number of retries is reached.
+pub async fn retry_with_backoff<F, Fut, T, E>(
+    mut fut: F,
+    max_retry: usize,
+    init_backoff: Duration,
+) -> Result<T, E>
+where
+    F: FnMut() -> Fut,
+    Fut: Future<Output = Result<T, E>>,
+{
+    let mut backoff = init_backoff;
+    for attempt in 0..max_retry {
+        match fut().await {
+            Ok(res) => return Ok(res),
+            Err(err) if attempt + 1 == max_retry => return Err(err),
+            Err(_) => {
+                tokio::time::sleep(backoff).await;
+                backoff *= 2;
+            }
+        }
+    }
+
+    unreachable!("loop should have returned before here")
+}
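+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    // Usage sketch for `retry_with_backoff` (an illustrative test, assuming
+    // tokio's test macro is available here): the closure fails twice and then
+    // succeeds, so the helper returns `Ok` on the third attempt.
+    #[tokio::test]
+    async fn retry_with_backoff_returns_after_transient_errors() {
+        let mut attempts = 0u32;
+        let res: Result<u32, ()> = retry_with_backoff(
+            || {
+                attempts += 1;
+                let attempt = attempts;
+                async move { if attempt < 3 { Err(()) } else { Ok(attempt) } }
+            },
+            5,
+            Duration::from_millis(1),
+        )
+        .await;
+        assert_eq!(res, Ok(3));
+    }
+}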